Using different websockets lib; bug fix stdin/stdout; cleanup
This commit is contained in:
		@@ -7,7 +7,7 @@ import threading
 | 
				
			|||||||
from libs.dto.lsp_messages import get_message_str
 | 
					from libs.dto.lsp_messages import get_message_str
 | 
				
			||||||
from libs.dto.lsp_message_structs import LSPResponseTypes, ClientRequest, ClientNotification
 | 
					from libs.dto.lsp_message_structs import LSPResponseTypes, ClientRequest, ClientNotification
 | 
				
			||||||
from .lsp_controller_stdin_stdout import LSPControllerSTDInSTDOut
 | 
					from .lsp_controller_stdin_stdout import LSPControllerSTDInSTDOut
 | 
				
			||||||
# from .lsp_controller_websocket import LSPControllerWebsocket
 | 
					from .lsp_controller_websocket import LSPControllerWebsocket
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -20,8 +20,8 @@ def _log_list():
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class LSPController(LSPControllerSTDInSTDOut):
 | 
					# class LSPController(LSPControllerSTDInSTDOut):
 | 
				
			||||||
# class LSPController(LSPControllerWebsocket):
 | 
					class LSPController(LSPControllerWebsocket):
 | 
				
			||||||
    def __init__(self):
 | 
					    def __init__(self):
 | 
				
			||||||
        super(LSPController).__init__()
 | 
					        super(LSPController).__init__()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -73,16 +73,29 @@ class LSPControllerEvents:
 | 
				
			|||||||
        params["textDocument"]["version"]    = data["version"]
 | 
					        params["textDocument"]["version"]    = data["version"]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        contentChanges         = params["contentChanges"][0]
 | 
					        contentChanges         = params["contentChanges"][0]
 | 
				
			||||||
        start                  = contentChanges["range"]["start"]
 | 
					 | 
				
			||||||
        end                    = contentChanges["range"]["end"]
 | 
					 | 
				
			||||||
        contentChanges["text"] = data["text"]
 | 
					        contentChanges["text"] = data["text"]
 | 
				
			||||||
        start["line"]          = data["line"]
 | 
					 | 
				
			||||||
        start["character"]     = 0
 | 
					 | 
				
			||||||
        end["line"]            = data["line"]
 | 
					 | 
				
			||||||
        end["character"]       = data["column"]
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
        GLib.idle_add( self.send_notification, method, params )
 | 
					        GLib.idle_add( self.send_notification, method, params )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # def _lsp_did_change(self, data: dict):
 | 
				
			||||||
 | 
					    #     method = data["method"]
 | 
				
			||||||
 | 
					    #     params = didchange_notification_range["params"]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    #     params["textDocument"]["uri"]        = data["uri"]
 | 
				
			||||||
 | 
					    #     params["textDocument"]["languageId"] = data["language_id"]
 | 
				
			||||||
 | 
					    #     params["textDocument"]["version"]    = data["version"]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    #     contentChanges         = params["contentChanges"][0]
 | 
				
			||||||
 | 
					    #     start                  = contentChanges["range"]["start"]
 | 
				
			||||||
 | 
					    #     end                    = contentChanges["range"]["end"]
 | 
				
			||||||
 | 
					    #     contentChanges["text"] = data["text"]
 | 
				
			||||||
 | 
					    #     start["line"]          = data["line"]
 | 
				
			||||||
 | 
					    #     start["character"]     = 0
 | 
				
			||||||
 | 
					    #     end["line"]            = data["line"]
 | 
				
			||||||
 | 
					    #     end["character"]       = data["column"]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    #     GLib.idle_add( self.send_notification, method, params )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def _lsp_definition(self, data: dict):
 | 
					    def _lsp_definition(self, data: dict):
 | 
				
			||||||
        method = data["method"]
 | 
					        method = data["method"]
 | 
				
			||||||
        params = definition_request["params"]
 | 
					        params = definition_request["params"]
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -114,4 +114,4 @@ class LSPControllerSTDInSTDOut(LSPControllerBase):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
            if not lsp_response: return
 | 
					            if not lsp_response: return
 | 
				
			||||||
            GLib.idle_add(self.handle_lsp_response, lsp_response)
 | 
					            GLib.idle_add(self.handle_lsp_response, lsp_response)
 | 
				
			||||||
            GLib.idle_add(self._monitor_lsp_response)
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,17 +1,16 @@
 | 
				
			|||||||
# Python imports
 | 
					# Python imports
 | 
				
			||||||
import subprocess
 | 
					import subprocess
 | 
				
			||||||
import asyncio
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
# Lib imports
 | 
					# Lib imports
 | 
				
			||||||
from gi.repository import GLib
 | 
					from gi.repository import GLib
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# Application imports
 | 
					# Application imports
 | 
				
			||||||
from libs import websockets
 | 
					# from libs import websockets
 | 
				
			||||||
 | 
					from libs.websocket_client import WebsocketClient
 | 
				
			||||||
from libs.dto.lsp_messages import LEN_HEADER, TYPE_HEADER, get_message_str, get_message_obj
 | 
					from libs.dto.lsp_messages import LEN_HEADER, TYPE_HEADER, get_message_str, get_message_obj
 | 
				
			||||||
from .lsp_controller_base import LSPControllerBase
 | 
					 | 
				
			||||||
from libs.dto.lsp_message_structs import \
 | 
					from libs.dto.lsp_message_structs import \
 | 
				
			||||||
    LSPResponseTypes, ClientRequest, ClientNotification, LSPResponseRequest, LSPResponseNotification
 | 
					    LSPResponseTypes, ClientRequest, ClientNotification, LSPResponseRequest, LSPResponseNotification
 | 
				
			||||||
 | 
					from .lsp_controller_base import LSPControllerBase
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -24,16 +23,7 @@ class LSPControllerWebsocket(LSPControllerBase):
 | 
				
			|||||||
        message      = f"Content-Length: {message_size}\r\n\r\n{message_str}"
 | 
					        message      = f"Content-Length: {message_size}\r\n\r\n{message_str}"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        self.log_list.add_log_entry("Client", data)
 | 
					        self.log_list.add_log_entry("Client", data)
 | 
				
			||||||
 | 
					        self.ws_client.send(message_str)
 | 
				
			||||||
        uri = "ws://localhost:4114"
 | 
					 | 
				
			||||||
        async def do_message(message_str):
 | 
					 | 
				
			||||||
            async with websockets.connect(uri) as websocket:
 | 
					 | 
				
			||||||
                await websocket.send(message_str)
 | 
					 | 
				
			||||||
                response = await websocket.recv()
 | 
					 | 
				
			||||||
                GLib.idle_add(self._monitor_lsp_response, response)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        asyncio.get_event_loop().run_until_complete( do_message(message_str) )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def start_lsp(self):
 | 
					    def start_lsp(self):
 | 
				
			||||||
        if not self._start_command: return
 | 
					        if not self._start_command: return
 | 
				
			||||||
@@ -57,6 +47,9 @@ class LSPControllerWebsocket(LSPControllerBase):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
            return
 | 
					            return
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.ws_client = WebsocketClient()
 | 
				
			||||||
 | 
					        self.ws_client.set_callback(self._monitor_lsp_response)
 | 
				
			||||||
 | 
					        self.ws_client.start_client()
 | 
				
			||||||
        return self.lsp_process.pid
 | 
					        return self.lsp_process.pid
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def stop_lsp(self):
 | 
					    def stop_lsp(self):
 | 
				
			||||||
@@ -64,6 +57,7 @@ class LSPControllerWebsocket(LSPControllerBase):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
        self._lsp_pid    = -1
 | 
					        self._lsp_pid    = -1
 | 
				
			||||||
        self._message_id = 0
 | 
					        self._message_id = 0
 | 
				
			||||||
 | 
					        self.ws_client.close_client()
 | 
				
			||||||
        self.lsp_process.terminate()
 | 
					        self.lsp_process.terminate()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def _monitor_lsp_response(self, data: None or {}):
 | 
					    def _monitor_lsp_response(self, data: None or {}):
 | 
				
			||||||
@@ -82,16 +76,5 @@ class LSPControllerWebsocket(LSPControllerBase):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
        response_id  = -1
 | 
					        response_id  = -1
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        print()
 | 
					 | 
				
			||||||
        print()
 | 
					 | 
				
			||||||
        print()
 | 
					 | 
				
			||||||
        print()
 | 
					 | 
				
			||||||
        print()
 | 
					 | 
				
			||||||
        print(f"Response: {lsp_response}")
 | 
					 | 
				
			||||||
        print()
 | 
					 | 
				
			||||||
        print()
 | 
					 | 
				
			||||||
        print()
 | 
					 | 
				
			||||||
        print()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if not lsp_response: return
 | 
					        if not lsp_response: return
 | 
				
			||||||
        self.handle_lsp_response(lsp_response)
 | 
					        self.handle_lsp_response(lsp_response)
 | 
				
			||||||
@@ -74,6 +74,22 @@ didclose_notification = {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
didchange_notification = {
 | 
					didchange_notification = {
 | 
				
			||||||
 | 
						"method": "textDocument/didChange",
 | 
				
			||||||
 | 
						"params": {
 | 
				
			||||||
 | 
						    "textDocument": {
 | 
				
			||||||
 | 
					            "uri": "file://",
 | 
				
			||||||
 | 
					            "languageId": "python",
 | 
				
			||||||
 | 
					            "version": 1,
 | 
				
			||||||
 | 
						    },
 | 
				
			||||||
 | 
						    "contentChanges": [
 | 
				
			||||||
 | 
						        {
 | 
				
			||||||
 | 
						            "text": ""
 | 
				
			||||||
 | 
						        }
 | 
				
			||||||
 | 
						    ]
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					didchange_notification_range = {
 | 
				
			||||||
	"method": "textDocument/didChange",
 | 
						"method": "textDocument/didChange",
 | 
				
			||||||
	"params": {
 | 
						"params": {
 | 
				
			||||||
	    "textDocument": {
 | 
						    "textDocument": {
 | 
				
			||||||
@@ -92,7 +108,8 @@ didchange_notification = {
 | 
				
			|||||||
	                "end": {
 | 
						                "end": {
 | 
				
			||||||
	                    "line": 1,
 | 
						                    "line": 1,
 | 
				
			||||||
                        "character": 1,
 | 
					                        "character": 1,
 | 
				
			||||||
	                }
 | 
						                },
 | 
				
			||||||
 | 
						                "rangeLength": 0
 | 
				
			||||||
	            }
 | 
						            }
 | 
				
			||||||
	        }
 | 
						        }
 | 
				
			||||||
	    ]
 | 
						    ]
 | 
				
			||||||
@@ -100,7 +117,6 @@ didchange_notification = {
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 | 
				
			||||||
# CompletionTriggerKind = 1 | 2 | 3;
 | 
					# CompletionTriggerKind = 1 | 2 | 3;
 | 
				
			||||||
# export const Invoked: 1 = 1;
 | 
					# export const Invoked: 1 = 1;
 | 
				
			||||||
# export const TriggerCharacter: 2 = 2;
 | 
					# export const TriggerCharacter: 2 = 2;
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										27
									
								
								src/libs/websocket/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										27
									
								
								src/libs/websocket/__init__.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,27 @@
 | 
				
			|||||||
 | 
					"""
 | 
				
			||||||
 | 
					__init__.py
 | 
				
			||||||
 | 
					websocket - WebSocket client library for Python
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Copyright 2024 engn33r
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					You may obtain a copy of the License at
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					limitations under the License.
 | 
				
			||||||
 | 
					"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from ._abnf import *
 | 
				
			||||||
 | 
					from ._app import WebSocketApp as WebSocketApp, setReconnect as setReconnect
 | 
				
			||||||
 | 
					from ._core import *
 | 
				
			||||||
 | 
					from ._exceptions import *
 | 
				
			||||||
 | 
					from ._logging import *
 | 
				
			||||||
 | 
					from ._socket import *
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					__version__ = "1.8.0"
 | 
				
			||||||
							
								
								
									
										453
									
								
								src/libs/websocket/_abnf.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										453
									
								
								src/libs/websocket/_abnf.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,453 @@
 | 
				
			|||||||
 | 
					import array
 | 
				
			||||||
 | 
					import os
 | 
				
			||||||
 | 
					import struct
 | 
				
			||||||
 | 
					import sys
 | 
				
			||||||
 | 
					from threading import Lock
 | 
				
			||||||
 | 
					from typing import Callable, Optional, Union
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from ._exceptions import WebSocketPayloadException, WebSocketProtocolException
 | 
				
			||||||
 | 
					from ._utils import validate_utf8
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					"""
 | 
				
			||||||
 | 
					_abnf.py
 | 
				
			||||||
 | 
					websocket - WebSocket client library for Python
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Copyright 2024 engn33r
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					You may obtain a copy of the License at
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					limitations under the License.
 | 
				
			||||||
 | 
					"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					try:
 | 
				
			||||||
 | 
					    # If wsaccel is available, use compiled routines to mask data.
 | 
				
			||||||
 | 
					    # wsaccel only provides around a 10% speed boost compared
 | 
				
			||||||
 | 
					    # to the websocket-client _mask() implementation.
 | 
				
			||||||
 | 
					    # Note that wsaccel is unmaintained.
 | 
				
			||||||
 | 
					    from wsaccel.xormask import XorMaskerSimple
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def _mask(mask_value: array.array, data_value: array.array) -> bytes:
 | 
				
			||||||
 | 
					        mask_result: bytes = XorMaskerSimple(mask_value).process(data_value)
 | 
				
			||||||
 | 
					        return mask_result
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					except ImportError:
 | 
				
			||||||
 | 
					    # wsaccel is not available, use websocket-client _mask()
 | 
				
			||||||
 | 
					    native_byteorder = sys.byteorder
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def _mask(mask_value: array.array, data_value: array.array) -> bytes:
 | 
				
			||||||
 | 
					        datalen = len(data_value)
 | 
				
			||||||
 | 
					        int_data_value = int.from_bytes(data_value, native_byteorder)
 | 
				
			||||||
 | 
					        int_mask_value = int.from_bytes(
 | 
				
			||||||
 | 
					            mask_value * (datalen // 4) + mask_value[: datalen % 4], native_byteorder
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        return (int_data_value ^ int_mask_value).to_bytes(datalen, native_byteorder)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					__all__ = [
 | 
				
			||||||
 | 
					    "ABNF",
 | 
				
			||||||
 | 
					    "continuous_frame",
 | 
				
			||||||
 | 
					    "frame_buffer",
 | 
				
			||||||
 | 
					    "STATUS_NORMAL",
 | 
				
			||||||
 | 
					    "STATUS_GOING_AWAY",
 | 
				
			||||||
 | 
					    "STATUS_PROTOCOL_ERROR",
 | 
				
			||||||
 | 
					    "STATUS_UNSUPPORTED_DATA_TYPE",
 | 
				
			||||||
 | 
					    "STATUS_STATUS_NOT_AVAILABLE",
 | 
				
			||||||
 | 
					    "STATUS_ABNORMAL_CLOSED",
 | 
				
			||||||
 | 
					    "STATUS_INVALID_PAYLOAD",
 | 
				
			||||||
 | 
					    "STATUS_POLICY_VIOLATION",
 | 
				
			||||||
 | 
					    "STATUS_MESSAGE_TOO_BIG",
 | 
				
			||||||
 | 
					    "STATUS_INVALID_EXTENSION",
 | 
				
			||||||
 | 
					    "STATUS_UNEXPECTED_CONDITION",
 | 
				
			||||||
 | 
					    "STATUS_BAD_GATEWAY",
 | 
				
			||||||
 | 
					    "STATUS_TLS_HANDSHAKE_ERROR",
 | 
				
			||||||
 | 
					]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# closing frame status codes.
 | 
				
			||||||
 | 
					STATUS_NORMAL = 1000
 | 
				
			||||||
 | 
					STATUS_GOING_AWAY = 1001
 | 
				
			||||||
 | 
					STATUS_PROTOCOL_ERROR = 1002
 | 
				
			||||||
 | 
					STATUS_UNSUPPORTED_DATA_TYPE = 1003
 | 
				
			||||||
 | 
					STATUS_STATUS_NOT_AVAILABLE = 1005
 | 
				
			||||||
 | 
					STATUS_ABNORMAL_CLOSED = 1006
 | 
				
			||||||
 | 
					STATUS_INVALID_PAYLOAD = 1007
 | 
				
			||||||
 | 
					STATUS_POLICY_VIOLATION = 1008
 | 
				
			||||||
 | 
					STATUS_MESSAGE_TOO_BIG = 1009
 | 
				
			||||||
 | 
					STATUS_INVALID_EXTENSION = 1010
 | 
				
			||||||
 | 
					STATUS_UNEXPECTED_CONDITION = 1011
 | 
				
			||||||
 | 
					STATUS_SERVICE_RESTART = 1012
 | 
				
			||||||
 | 
					STATUS_TRY_AGAIN_LATER = 1013
 | 
				
			||||||
 | 
					STATUS_BAD_GATEWAY = 1014
 | 
				
			||||||
 | 
					STATUS_TLS_HANDSHAKE_ERROR = 1015
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					VALID_CLOSE_STATUS = (
 | 
				
			||||||
 | 
					    STATUS_NORMAL,
 | 
				
			||||||
 | 
					    STATUS_GOING_AWAY,
 | 
				
			||||||
 | 
					    STATUS_PROTOCOL_ERROR,
 | 
				
			||||||
 | 
					    STATUS_UNSUPPORTED_DATA_TYPE,
 | 
				
			||||||
 | 
					    STATUS_INVALID_PAYLOAD,
 | 
				
			||||||
 | 
					    STATUS_POLICY_VIOLATION,
 | 
				
			||||||
 | 
					    STATUS_MESSAGE_TOO_BIG,
 | 
				
			||||||
 | 
					    STATUS_INVALID_EXTENSION,
 | 
				
			||||||
 | 
					    STATUS_UNEXPECTED_CONDITION,
 | 
				
			||||||
 | 
					    STATUS_SERVICE_RESTART,
 | 
				
			||||||
 | 
					    STATUS_TRY_AGAIN_LATER,
 | 
				
			||||||
 | 
					    STATUS_BAD_GATEWAY,
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class ABNF:
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    ABNF frame class.
 | 
				
			||||||
 | 
					    See http://tools.ietf.org/html/rfc5234
 | 
				
			||||||
 | 
					    and http://tools.ietf.org/html/rfc6455#section-5.2
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # operation code values.
 | 
				
			||||||
 | 
					    OPCODE_CONT = 0x0
 | 
				
			||||||
 | 
					    OPCODE_TEXT = 0x1
 | 
				
			||||||
 | 
					    OPCODE_BINARY = 0x2
 | 
				
			||||||
 | 
					    OPCODE_CLOSE = 0x8
 | 
				
			||||||
 | 
					    OPCODE_PING = 0x9
 | 
				
			||||||
 | 
					    OPCODE_PONG = 0xA
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # available operation code value tuple
 | 
				
			||||||
 | 
					    OPCODES = (
 | 
				
			||||||
 | 
					        OPCODE_CONT,
 | 
				
			||||||
 | 
					        OPCODE_TEXT,
 | 
				
			||||||
 | 
					        OPCODE_BINARY,
 | 
				
			||||||
 | 
					        OPCODE_CLOSE,
 | 
				
			||||||
 | 
					        OPCODE_PING,
 | 
				
			||||||
 | 
					        OPCODE_PONG,
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # opcode human readable string
 | 
				
			||||||
 | 
					    OPCODE_MAP = {
 | 
				
			||||||
 | 
					        OPCODE_CONT: "cont",
 | 
				
			||||||
 | 
					        OPCODE_TEXT: "text",
 | 
				
			||||||
 | 
					        OPCODE_BINARY: "binary",
 | 
				
			||||||
 | 
					        OPCODE_CLOSE: "close",
 | 
				
			||||||
 | 
					        OPCODE_PING: "ping",
 | 
				
			||||||
 | 
					        OPCODE_PONG: "pong",
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # data length threshold.
 | 
				
			||||||
 | 
					    LENGTH_7 = 0x7E
 | 
				
			||||||
 | 
					    LENGTH_16 = 1 << 16
 | 
				
			||||||
 | 
					    LENGTH_63 = 1 << 63
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __init__(
 | 
				
			||||||
 | 
					        self,
 | 
				
			||||||
 | 
					        fin: int = 0,
 | 
				
			||||||
 | 
					        rsv1: int = 0,
 | 
				
			||||||
 | 
					        rsv2: int = 0,
 | 
				
			||||||
 | 
					        rsv3: int = 0,
 | 
				
			||||||
 | 
					        opcode: int = OPCODE_TEXT,
 | 
				
			||||||
 | 
					        mask_value: int = 1,
 | 
				
			||||||
 | 
					        data: Union[str, bytes, None] = "",
 | 
				
			||||||
 | 
					    ) -> None:
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Constructor for ABNF. Please check RFC for arguments.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        self.fin = fin
 | 
				
			||||||
 | 
					        self.rsv1 = rsv1
 | 
				
			||||||
 | 
					        self.rsv2 = rsv2
 | 
				
			||||||
 | 
					        self.rsv3 = rsv3
 | 
				
			||||||
 | 
					        self.opcode = opcode
 | 
				
			||||||
 | 
					        self.mask_value = mask_value
 | 
				
			||||||
 | 
					        if data is None:
 | 
				
			||||||
 | 
					            data = ""
 | 
				
			||||||
 | 
					        self.data = data
 | 
				
			||||||
 | 
					        self.get_mask_key = os.urandom
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def validate(self, skip_utf8_validation: bool = False) -> None:
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Validate the ABNF frame.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Parameters
 | 
				
			||||||
 | 
					        ----------
 | 
				
			||||||
 | 
					        skip_utf8_validation: skip utf8 validation.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        if self.rsv1 or self.rsv2 or self.rsv3:
 | 
				
			||||||
 | 
					            raise WebSocketProtocolException("rsv is not implemented, yet")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if self.opcode not in ABNF.OPCODES:
 | 
				
			||||||
 | 
					            raise WebSocketProtocolException("Invalid opcode %r", self.opcode)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if self.opcode == ABNF.OPCODE_PING and not self.fin:
 | 
				
			||||||
 | 
					            raise WebSocketProtocolException("Invalid ping frame.")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if self.opcode == ABNF.OPCODE_CLOSE:
 | 
				
			||||||
 | 
					            l = len(self.data)
 | 
				
			||||||
 | 
					            if not l:
 | 
				
			||||||
 | 
					                return
 | 
				
			||||||
 | 
					            if l == 1 or l >= 126:
 | 
				
			||||||
 | 
					                raise WebSocketProtocolException("Invalid close frame.")
 | 
				
			||||||
 | 
					            if l > 2 and not skip_utf8_validation and not validate_utf8(self.data[2:]):
 | 
				
			||||||
 | 
					                raise WebSocketProtocolException("Invalid close frame.")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            code = 256 * int(self.data[0]) + int(self.data[1])
 | 
				
			||||||
 | 
					            if not self._is_valid_close_status(code):
 | 
				
			||||||
 | 
					                raise WebSocketProtocolException("Invalid close opcode %r", code)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @staticmethod
 | 
				
			||||||
 | 
					    def _is_valid_close_status(code: int) -> bool:
 | 
				
			||||||
 | 
					        return code in VALID_CLOSE_STATUS or (3000 <= code < 5000)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __str__(self) -> str:
 | 
				
			||||||
 | 
					        return f"fin={self.fin} opcode={self.opcode} data={self.data}"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @staticmethod
 | 
				
			||||||
 | 
					    def create_frame(data: Union[bytes, str], opcode: int, fin: int = 1) -> "ABNF":
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Create frame to send text, binary and other data.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Parameters
 | 
				
			||||||
 | 
					        ----------
 | 
				
			||||||
 | 
					        data: str
 | 
				
			||||||
 | 
					            data to send. This is string value(byte array).
 | 
				
			||||||
 | 
					            If opcode is OPCODE_TEXT and this value is unicode,
 | 
				
			||||||
 | 
					            data value is converted into unicode string, automatically.
 | 
				
			||||||
 | 
					        opcode: int
 | 
				
			||||||
 | 
					            operation code. please see OPCODE_MAP.
 | 
				
			||||||
 | 
					        fin: int
 | 
				
			||||||
 | 
					            fin flag. if set to 0, create continue fragmentation.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        if opcode == ABNF.OPCODE_TEXT and isinstance(data, str):
 | 
				
			||||||
 | 
					            data = data.encode("utf-8")
 | 
				
			||||||
 | 
					        # mask must be set if send data from client
 | 
				
			||||||
 | 
					        return ABNF(fin, 0, 0, 0, opcode, 1, data)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def format(self) -> bytes:
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Format this object to string(byte array) to send data to server.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        if any(x not in (0, 1) for x in [self.fin, self.rsv1, self.rsv2, self.rsv3]):
 | 
				
			||||||
 | 
					            raise ValueError("not 0 or 1")
 | 
				
			||||||
 | 
					        if self.opcode not in ABNF.OPCODES:
 | 
				
			||||||
 | 
					            raise ValueError("Invalid OPCODE")
 | 
				
			||||||
 | 
					        length = len(self.data)
 | 
				
			||||||
 | 
					        if length >= ABNF.LENGTH_63:
 | 
				
			||||||
 | 
					            raise ValueError("data is too long")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        frame_header = chr(
 | 
				
			||||||
 | 
					            self.fin << 7
 | 
				
			||||||
 | 
					            | self.rsv1 << 6
 | 
				
			||||||
 | 
					            | self.rsv2 << 5
 | 
				
			||||||
 | 
					            | self.rsv3 << 4
 | 
				
			||||||
 | 
					            | self.opcode
 | 
				
			||||||
 | 
					        ).encode("latin-1")
 | 
				
			||||||
 | 
					        if length < ABNF.LENGTH_7:
 | 
				
			||||||
 | 
					            frame_header += chr(self.mask_value << 7 | length).encode("latin-1")
 | 
				
			||||||
 | 
					        elif length < ABNF.LENGTH_16:
 | 
				
			||||||
 | 
					            frame_header += chr(self.mask_value << 7 | 0x7E).encode("latin-1")
 | 
				
			||||||
 | 
					            frame_header += struct.pack("!H", length)
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            frame_header += chr(self.mask_value << 7 | 0x7F).encode("latin-1")
 | 
				
			||||||
 | 
					            frame_header += struct.pack("!Q", length)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if not self.mask_value:
 | 
				
			||||||
 | 
					            if isinstance(self.data, str):
 | 
				
			||||||
 | 
					                self.data = self.data.encode("utf-8")
 | 
				
			||||||
 | 
					            return frame_header + self.data
 | 
				
			||||||
 | 
					        mask_key = self.get_mask_key(4)
 | 
				
			||||||
 | 
					        return frame_header + self._get_masked(mask_key)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def _get_masked(self, mask_key: Union[str, bytes]) -> bytes:
 | 
				
			||||||
 | 
					        s = ABNF.mask(mask_key, self.data)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if isinstance(mask_key, str):
 | 
				
			||||||
 | 
					            mask_key = mask_key.encode("utf-8")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return mask_key + s
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @staticmethod
 | 
				
			||||||
 | 
					    def mask(mask_key: Union[str, bytes], data: Union[str, bytes]) -> bytes:
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Mask or unmask data. Just do xor for each byte
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Parameters
 | 
				
			||||||
 | 
					        ----------
 | 
				
			||||||
 | 
					        mask_key: bytes or str
 | 
				
			||||||
 | 
					            4 byte mask.
 | 
				
			||||||
 | 
					        data: bytes or str
 | 
				
			||||||
 | 
					            data to mask/unmask.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        if data is None:
 | 
				
			||||||
 | 
					            data = ""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if isinstance(mask_key, str):
 | 
				
			||||||
 | 
					            mask_key = mask_key.encode("latin-1")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if isinstance(data, str):
 | 
				
			||||||
 | 
					            data = data.encode("latin-1")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return _mask(array.array("B", mask_key), array.array("B", data))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class frame_buffer:
 | 
				
			||||||
 | 
					    _HEADER_MASK_INDEX = 5
 | 
				
			||||||
 | 
					    _HEADER_LENGTH_INDEX = 6
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __init__(
 | 
				
			||||||
 | 
					        self, recv_fn: Callable[[int], int], skip_utf8_validation: bool
 | 
				
			||||||
 | 
					    ) -> None:
 | 
				
			||||||
 | 
					        self.recv = recv_fn
 | 
				
			||||||
 | 
					        self.skip_utf8_validation = skip_utf8_validation
 | 
				
			||||||
 | 
					        # Buffers over the packets from the layer beneath until desired amount
 | 
				
			||||||
 | 
					        # bytes of bytes are received.
 | 
				
			||||||
 | 
					        self.recv_buffer: list = []
 | 
				
			||||||
 | 
					        self.clear()
 | 
				
			||||||
 | 
					        self.lock = Lock()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def clear(self) -> None:
 | 
				
			||||||
 | 
					        self.header: Optional[tuple] = None
 | 
				
			||||||
 | 
					        self.length: Optional[int] = None
 | 
				
			||||||
 | 
					        self.mask_value: Union[bytes, str, None] = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def has_received_header(self) -> bool:
 | 
				
			||||||
 | 
					        return self.header is None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def recv_header(self) -> None:
 | 
				
			||||||
 | 
					        header = self.recv_strict(2)
 | 
				
			||||||
 | 
					        b1 = header[0]
 | 
				
			||||||
 | 
					        fin = b1 >> 7 & 1
 | 
				
			||||||
 | 
					        rsv1 = b1 >> 6 & 1
 | 
				
			||||||
 | 
					        rsv2 = b1 >> 5 & 1
 | 
				
			||||||
 | 
					        rsv3 = b1 >> 4 & 1
 | 
				
			||||||
 | 
					        opcode = b1 & 0xF
 | 
				
			||||||
 | 
					        b2 = header[1]
 | 
				
			||||||
 | 
					        has_mask = b2 >> 7 & 1
 | 
				
			||||||
 | 
					        length_bits = b2 & 0x7F
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.header = (fin, rsv1, rsv2, rsv3, opcode, has_mask, length_bits)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def has_mask(self) -> Union[bool, int]:
 | 
				
			||||||
 | 
					        if not self.header:
 | 
				
			||||||
 | 
					            return False
 | 
				
			||||||
 | 
					        header_val: int = self.header[frame_buffer._HEADER_MASK_INDEX]
 | 
				
			||||||
 | 
					        return header_val
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def has_received_length(self) -> bool:
 | 
				
			||||||
 | 
					        return self.length is None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def recv_length(self) -> None:
 | 
				
			||||||
 | 
					        bits = self.header[frame_buffer._HEADER_LENGTH_INDEX]
 | 
				
			||||||
 | 
					        length_bits = bits & 0x7F
 | 
				
			||||||
 | 
					        if length_bits == 0x7E:
 | 
				
			||||||
 | 
					            v = self.recv_strict(2)
 | 
				
			||||||
 | 
					            self.length = struct.unpack("!H", v)[0]
 | 
				
			||||||
 | 
					        elif length_bits == 0x7F:
 | 
				
			||||||
 | 
					            v = self.recv_strict(8)
 | 
				
			||||||
 | 
					            self.length = struct.unpack("!Q", v)[0]
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            self.length = length_bits
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def has_received_mask(self) -> bool:
 | 
				
			||||||
 | 
					        return self.mask_value is None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def recv_mask(self) -> None:
 | 
				
			||||||
 | 
					        self.mask_value = self.recv_strict(4) if self.has_mask() else ""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def recv_frame(self) -> ABNF:
 | 
				
			||||||
 | 
					        with self.lock:
 | 
				
			||||||
 | 
					            # Header
 | 
				
			||||||
 | 
					            if self.has_received_header():
 | 
				
			||||||
 | 
					                self.recv_header()
 | 
				
			||||||
 | 
					            (fin, rsv1, rsv2, rsv3, opcode, has_mask, _) = self.header
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            # Frame length
 | 
				
			||||||
 | 
					            if self.has_received_length():
 | 
				
			||||||
 | 
					                self.recv_length()
 | 
				
			||||||
 | 
					            length = self.length
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            # Mask
 | 
				
			||||||
 | 
					            if self.has_received_mask():
 | 
				
			||||||
 | 
					                self.recv_mask()
 | 
				
			||||||
 | 
					            mask_value = self.mask_value
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            # Payload
 | 
				
			||||||
 | 
					            payload = self.recv_strict(length)
 | 
				
			||||||
 | 
					            if has_mask:
 | 
				
			||||||
 | 
					                payload = ABNF.mask(mask_value, payload)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            # Reset for next frame
 | 
				
			||||||
 | 
					            self.clear()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            frame = ABNF(fin, rsv1, rsv2, rsv3, opcode, has_mask, payload)
 | 
				
			||||||
 | 
					            frame.validate(self.skip_utf8_validation)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return frame
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def recv_strict(self, bufsize: int) -> bytes:
 | 
				
			||||||
 | 
					        shortage = bufsize - sum(map(len, self.recv_buffer))
 | 
				
			||||||
 | 
					        while shortage > 0:
 | 
				
			||||||
 | 
					            # Limit buffer size that we pass to socket.recv() to avoid
 | 
				
			||||||
 | 
					            # fragmenting the heap -- the number of bytes recv() actually
 | 
				
			||||||
 | 
					            # reads is limited by socket buffer and is relatively small,
 | 
				
			||||||
 | 
					            # yet passing large numbers repeatedly causes lots of large
 | 
				
			||||||
 | 
					            # buffers allocated and then shrunk, which results in
 | 
				
			||||||
 | 
					            # fragmentation.
 | 
				
			||||||
 | 
					            bytes_ = self.recv(min(16384, shortage))
 | 
				
			||||||
 | 
					            self.recv_buffer.append(bytes_)
 | 
				
			||||||
 | 
					            shortage -= len(bytes_)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        unified = b"".join(self.recv_buffer)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if shortage == 0:
 | 
				
			||||||
 | 
					            self.recv_buffer = []
 | 
				
			||||||
 | 
					            return unified
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            self.recv_buffer = [unified[bufsize:]]
 | 
				
			||||||
 | 
					            return unified[:bufsize]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class continuous_frame:
 | 
				
			||||||
 | 
					    def __init__(self, fire_cont_frame: bool, skip_utf8_validation: bool) -> None:
 | 
				
			||||||
 | 
					        self.fire_cont_frame = fire_cont_frame
 | 
				
			||||||
 | 
					        self.skip_utf8_validation = skip_utf8_validation
 | 
				
			||||||
 | 
					        self.cont_data: Optional[list] = None
 | 
				
			||||||
 | 
					        self.recving_frames: Optional[int] = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def validate(self, frame: ABNF) -> None:
 | 
				
			||||||
 | 
					        if not self.recving_frames and frame.opcode == ABNF.OPCODE_CONT:
 | 
				
			||||||
 | 
					            raise WebSocketProtocolException("Illegal frame")
 | 
				
			||||||
 | 
					        if self.recving_frames and frame.opcode in (
 | 
				
			||||||
 | 
					            ABNF.OPCODE_TEXT,
 | 
				
			||||||
 | 
					            ABNF.OPCODE_BINARY,
 | 
				
			||||||
 | 
					        ):
 | 
				
			||||||
 | 
					            raise WebSocketProtocolException("Illegal frame")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def add(self, frame: ABNF) -> None:
 | 
				
			||||||
 | 
					        if self.cont_data:
 | 
				
			||||||
 | 
					            self.cont_data[1] += frame.data
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            if frame.opcode in (ABNF.OPCODE_TEXT, ABNF.OPCODE_BINARY):
 | 
				
			||||||
 | 
					                self.recving_frames = frame.opcode
 | 
				
			||||||
 | 
					            self.cont_data = [frame.opcode, frame.data]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if frame.fin:
 | 
				
			||||||
 | 
					            self.recving_frames = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def is_fire(self, frame: ABNF) -> Union[bool, int]:
 | 
				
			||||||
 | 
					        return frame.fin or self.fire_cont_frame
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def extract(self, frame: ABNF) -> tuple:
 | 
				
			||||||
 | 
					        data = self.cont_data
 | 
				
			||||||
 | 
					        self.cont_data = None
 | 
				
			||||||
 | 
					        frame.data = data[1]
 | 
				
			||||||
 | 
					        if (
 | 
				
			||||||
 | 
					            not self.fire_cont_frame
 | 
				
			||||||
 | 
					            and data[0] == ABNF.OPCODE_TEXT
 | 
				
			||||||
 | 
					            and not self.skip_utf8_validation
 | 
				
			||||||
 | 
					            and not validate_utf8(frame.data)
 | 
				
			||||||
 | 
					        ):
 | 
				
			||||||
 | 
					            raise WebSocketPayloadException(f"cannot decode: {repr(frame.data)}")
 | 
				
			||||||
 | 
					        return data[0], frame
 | 
				
			||||||
							
								
								
									
										677
									
								
								src/libs/websocket/_app.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										677
									
								
								src/libs/websocket/_app.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,677 @@
 | 
				
			|||||||
 | 
					import inspect
 | 
				
			||||||
 | 
					import selectors
 | 
				
			||||||
 | 
					import socket
 | 
				
			||||||
 | 
					import threading
 | 
				
			||||||
 | 
					import time
 | 
				
			||||||
 | 
					from typing import Any, Callable, Optional, Union
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from . import _logging
 | 
				
			||||||
 | 
					from ._abnf import ABNF
 | 
				
			||||||
 | 
					from ._core import WebSocket, getdefaulttimeout
 | 
				
			||||||
 | 
					from ._exceptions import (
 | 
				
			||||||
 | 
					    WebSocketConnectionClosedException,
 | 
				
			||||||
 | 
					    WebSocketException,
 | 
				
			||||||
 | 
					    WebSocketTimeoutException,
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					from ._ssl_compat import SSLEOFError
 | 
				
			||||||
 | 
					from ._url import parse_url
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					"""
 | 
				
			||||||
 | 
					_app.py
 | 
				
			||||||
 | 
					websocket - WebSocket client library for Python
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Copyright 2024 engn33r
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					You may obtain a copy of the License at
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					limitations under the License.
 | 
				
			||||||
 | 
					"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					__all__ = ["WebSocketApp"]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					RECONNECT = 0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def setReconnect(reconnectInterval: int) -> None:
 | 
				
			||||||
 | 
					    global RECONNECT
 | 
				
			||||||
 | 
					    RECONNECT = reconnectInterval
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class DispatcherBase:
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    DispatcherBase
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __init__(self, app: Any, ping_timeout: Union[float, int, None]) -> None:
 | 
				
			||||||
 | 
					        self.app = app
 | 
				
			||||||
 | 
					        self.ping_timeout = ping_timeout
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def timeout(self, seconds: Union[float, int, None], callback: Callable) -> None:
 | 
				
			||||||
 | 
					        time.sleep(seconds)
 | 
				
			||||||
 | 
					        callback()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def reconnect(self, seconds: int, reconnector: Callable) -> None:
 | 
				
			||||||
 | 
					        try:
 | 
				
			||||||
 | 
					            _logging.info(
 | 
				
			||||||
 | 
					                f"reconnect() - retrying in {seconds} seconds [{len(inspect.stack())} frames in stack]"
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					            time.sleep(seconds)
 | 
				
			||||||
 | 
					            reconnector(reconnecting=True)
 | 
				
			||||||
 | 
					        except KeyboardInterrupt as e:
 | 
				
			||||||
 | 
					            _logging.info(f"User exited {e}")
 | 
				
			||||||
 | 
					            raise e
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class Dispatcher(DispatcherBase):
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    Dispatcher
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def read(
 | 
				
			||||||
 | 
					        self,
 | 
				
			||||||
 | 
					        sock: socket.socket,
 | 
				
			||||||
 | 
					        read_callback: Callable,
 | 
				
			||||||
 | 
					        check_callback: Callable,
 | 
				
			||||||
 | 
					    ) -> None:
 | 
				
			||||||
 | 
					        sel = selectors.DefaultSelector()
 | 
				
			||||||
 | 
					        sel.register(self.app.sock.sock, selectors.EVENT_READ)
 | 
				
			||||||
 | 
					        try:
 | 
				
			||||||
 | 
					            while self.app.keep_running:
 | 
				
			||||||
 | 
					                if sel.select(self.ping_timeout):
 | 
				
			||||||
 | 
					                    if not read_callback():
 | 
				
			||||||
 | 
					                        break
 | 
				
			||||||
 | 
					                check_callback()
 | 
				
			||||||
 | 
					        finally:
 | 
				
			||||||
 | 
					            sel.close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class SSLDispatcher(DispatcherBase):
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    SSLDispatcher
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def read(
 | 
				
			||||||
 | 
					        self,
 | 
				
			||||||
 | 
					        sock: socket.socket,
 | 
				
			||||||
 | 
					        read_callback: Callable,
 | 
				
			||||||
 | 
					        check_callback: Callable,
 | 
				
			||||||
 | 
					    ) -> None:
 | 
				
			||||||
 | 
					        sock = self.app.sock.sock
 | 
				
			||||||
 | 
					        sel = selectors.DefaultSelector()
 | 
				
			||||||
 | 
					        sel.register(sock, selectors.EVENT_READ)
 | 
				
			||||||
 | 
					        try:
 | 
				
			||||||
 | 
					            while self.app.keep_running:
 | 
				
			||||||
 | 
					                if self.select(sock, sel):
 | 
				
			||||||
 | 
					                    if not read_callback():
 | 
				
			||||||
 | 
					                        break
 | 
				
			||||||
 | 
					                check_callback()
 | 
				
			||||||
 | 
					        finally:
 | 
				
			||||||
 | 
					            sel.close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def select(self, sock, sel: selectors.DefaultSelector):
 | 
				
			||||||
 | 
					        sock = self.app.sock.sock
 | 
				
			||||||
 | 
					        if sock.pending():
 | 
				
			||||||
 | 
					            return [
 | 
				
			||||||
 | 
					                sock,
 | 
				
			||||||
 | 
					            ]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        r = sel.select(self.ping_timeout)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if len(r) > 0:
 | 
				
			||||||
 | 
					            return r[0][0]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class WrappedDispatcher:
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    WrappedDispatcher
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __init__(self, app, ping_timeout: Union[float, int, None], dispatcher) -> None:
 | 
				
			||||||
 | 
					        self.app = app
 | 
				
			||||||
 | 
					        self.ping_timeout = ping_timeout
 | 
				
			||||||
 | 
					        self.dispatcher = dispatcher
 | 
				
			||||||
 | 
					        dispatcher.signal(2, dispatcher.abort)  # keyboard interrupt
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def read(
 | 
				
			||||||
 | 
					        self,
 | 
				
			||||||
 | 
					        sock: socket.socket,
 | 
				
			||||||
 | 
					        read_callback: Callable,
 | 
				
			||||||
 | 
					        check_callback: Callable,
 | 
				
			||||||
 | 
					    ) -> None:
 | 
				
			||||||
 | 
					        self.dispatcher.read(sock, read_callback)
 | 
				
			||||||
 | 
					        self.ping_timeout and self.timeout(self.ping_timeout, check_callback)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def timeout(self, seconds: float, callback: Callable) -> None:
 | 
				
			||||||
 | 
					        self.dispatcher.timeout(seconds, callback)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def reconnect(self, seconds: int, reconnector: Callable) -> None:
 | 
				
			||||||
 | 
					        self.timeout(seconds, reconnector)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class WebSocketApp:
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    Higher level of APIs are provided. The interface is like JavaScript WebSocket object.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __init__(
 | 
				
			||||||
 | 
					        self,
 | 
				
			||||||
 | 
					        url: str,
 | 
				
			||||||
 | 
					        header: Union[list, dict, Callable, None] = None,
 | 
				
			||||||
 | 
					        on_open: Optional[Callable[[WebSocket], None]] = None,
 | 
				
			||||||
 | 
					        on_reconnect: Optional[Callable[[WebSocket], None]] = None,
 | 
				
			||||||
 | 
					        on_message: Optional[Callable[[WebSocket, Any], None]] = None,
 | 
				
			||||||
 | 
					        on_error: Optional[Callable[[WebSocket, Any], None]] = None,
 | 
				
			||||||
 | 
					        on_close: Optional[Callable[[WebSocket, Any, Any], None]] = None,
 | 
				
			||||||
 | 
					        on_ping: Optional[Callable] = None,
 | 
				
			||||||
 | 
					        on_pong: Optional[Callable] = None,
 | 
				
			||||||
 | 
					        on_cont_message: Optional[Callable] = None,
 | 
				
			||||||
 | 
					        keep_running: bool = True,
 | 
				
			||||||
 | 
					        get_mask_key: Optional[Callable] = None,
 | 
				
			||||||
 | 
					        cookie: Optional[str] = None,
 | 
				
			||||||
 | 
					        subprotocols: Optional[list] = None,
 | 
				
			||||||
 | 
					        on_data: Optional[Callable] = None,
 | 
				
			||||||
 | 
					        socket: Optional[socket.socket] = None,
 | 
				
			||||||
 | 
					    ) -> None:
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        WebSocketApp initialization
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Parameters
 | 
				
			||||||
 | 
					        ----------
 | 
				
			||||||
 | 
					        url: str
 | 
				
			||||||
 | 
					            Websocket url.
 | 
				
			||||||
 | 
					        header: list or dict or Callable
 | 
				
			||||||
 | 
					            Custom header for websocket handshake.
 | 
				
			||||||
 | 
					            If the parameter is a callable object, it is called just before the connection attempt.
 | 
				
			||||||
 | 
					            The returned dict or list is used as custom header value.
 | 
				
			||||||
 | 
					            This could be useful in order to properly setup timestamp dependent headers.
 | 
				
			||||||
 | 
					        on_open: function
 | 
				
			||||||
 | 
					            Callback object which is called at opening websocket.
 | 
				
			||||||
 | 
					            on_open has one argument.
 | 
				
			||||||
 | 
					            The 1st argument is this class object.
 | 
				
			||||||
 | 
					        on_reconnect: function
 | 
				
			||||||
 | 
					            Callback object which is called at reconnecting websocket.
 | 
				
			||||||
 | 
					            on_reconnect has one argument.
 | 
				
			||||||
 | 
					            The 1st argument is this class object.
 | 
				
			||||||
 | 
					        on_message: function
 | 
				
			||||||
 | 
					            Callback object which is called when received data.
 | 
				
			||||||
 | 
					            on_message has 2 arguments.
 | 
				
			||||||
 | 
					            The 1st argument is this class object.
 | 
				
			||||||
 | 
					            The 2nd argument is utf-8 data received from the server.
 | 
				
			||||||
 | 
					        on_error: function
 | 
				
			||||||
 | 
					            Callback object which is called when we get error.
 | 
				
			||||||
 | 
					            on_error has 2 arguments.
 | 
				
			||||||
 | 
					            The 1st argument is this class object.
 | 
				
			||||||
 | 
					            The 2nd argument is exception object.
 | 
				
			||||||
 | 
					        on_close: function
 | 
				
			||||||
 | 
					            Callback object which is called when connection is closed.
 | 
				
			||||||
 | 
					            on_close has 3 arguments.
 | 
				
			||||||
 | 
					            The 1st argument is this class object.
 | 
				
			||||||
 | 
					            The 2nd argument is close_status_code.
 | 
				
			||||||
 | 
					            The 3rd argument is close_msg.
 | 
				
			||||||
 | 
					        on_cont_message: function
 | 
				
			||||||
 | 
					            Callback object which is called when a continuation
 | 
				
			||||||
 | 
					            frame is received.
 | 
				
			||||||
 | 
					            on_cont_message has 3 arguments.
 | 
				
			||||||
 | 
					            The 1st argument is this class object.
 | 
				
			||||||
 | 
					            The 2nd argument is utf-8 string which we get from the server.
 | 
				
			||||||
 | 
					            The 3rd argument is continue flag. if 0, the data continue
 | 
				
			||||||
 | 
					            to next frame data
 | 
				
			||||||
 | 
					        on_data: function
 | 
				
			||||||
 | 
					            Callback object which is called when a message received.
 | 
				
			||||||
 | 
					            This is called before on_message or on_cont_message,
 | 
				
			||||||
 | 
					            and then on_message or on_cont_message is called.
 | 
				
			||||||
 | 
					            on_data has 4 argument.
 | 
				
			||||||
 | 
					            The 1st argument is this class object.
 | 
				
			||||||
 | 
					            The 2nd argument is utf-8 string which we get from the server.
 | 
				
			||||||
 | 
					            The 3rd argument is data type. ABNF.OPCODE_TEXT or ABNF.OPCODE_BINARY will be came.
 | 
				
			||||||
 | 
					            The 4th argument is continue flag. If 0, the data continue
 | 
				
			||||||
 | 
					        keep_running: bool
 | 
				
			||||||
 | 
					            This parameter is obsolete and ignored.
 | 
				
			||||||
 | 
					        get_mask_key: function
 | 
				
			||||||
 | 
					            A callable function to get new mask keys, see the
 | 
				
			||||||
 | 
					            WebSocket.set_mask_key's docstring for more information.
 | 
				
			||||||
 | 
					        cookie: str
 | 
				
			||||||
 | 
					            Cookie value.
 | 
				
			||||||
 | 
					        subprotocols: list
 | 
				
			||||||
 | 
					            List of available sub protocols. Default is None.
 | 
				
			||||||
 | 
					        socket: socket
 | 
				
			||||||
 | 
					            Pre-initialized stream socket.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        self.url = url
 | 
				
			||||||
 | 
					        self.header = header if header is not None else []
 | 
				
			||||||
 | 
					        self.cookie = cookie
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.on_open = on_open
 | 
				
			||||||
 | 
					        self.on_reconnect = on_reconnect
 | 
				
			||||||
 | 
					        self.on_message = on_message
 | 
				
			||||||
 | 
					        self.on_data = on_data
 | 
				
			||||||
 | 
					        self.on_error = on_error
 | 
				
			||||||
 | 
					        self.on_close = on_close
 | 
				
			||||||
 | 
					        self.on_ping = on_ping
 | 
				
			||||||
 | 
					        self.on_pong = on_pong
 | 
				
			||||||
 | 
					        self.on_cont_message = on_cont_message
 | 
				
			||||||
 | 
					        self.keep_running = False
 | 
				
			||||||
 | 
					        self.get_mask_key = get_mask_key
 | 
				
			||||||
 | 
					        self.sock: Optional[WebSocket] = None
 | 
				
			||||||
 | 
					        self.last_ping_tm = float(0)
 | 
				
			||||||
 | 
					        self.last_pong_tm = float(0)
 | 
				
			||||||
 | 
					        self.ping_thread: Optional[threading.Thread] = None
 | 
				
			||||||
 | 
					        self.stop_ping: Optional[threading.Event] = None
 | 
				
			||||||
 | 
					        self.ping_interval = float(0)
 | 
				
			||||||
 | 
					        self.ping_timeout: Union[float, int, None] = None
 | 
				
			||||||
 | 
					        self.ping_payload = ""
 | 
				
			||||||
 | 
					        self.subprotocols = subprotocols
 | 
				
			||||||
 | 
					        self.prepared_socket = socket
 | 
				
			||||||
 | 
					        self.has_errored = False
 | 
				
			||||||
 | 
					        self.has_done_teardown = False
 | 
				
			||||||
 | 
					        self.has_done_teardown_lock = threading.Lock()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def send(self, data: Union[bytes, str], opcode: int = ABNF.OPCODE_TEXT) -> None:
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        send message
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Parameters
 | 
				
			||||||
 | 
					        ----------
 | 
				
			||||||
 | 
					        data: str
 | 
				
			||||||
 | 
					            Message to send. If you set opcode to OPCODE_TEXT,
 | 
				
			||||||
 | 
					            data must be utf-8 string or unicode.
 | 
				
			||||||
 | 
					        opcode: int
 | 
				
			||||||
 | 
					            Operation code of data. Default is OPCODE_TEXT.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if not self.sock or self.sock.send(data, opcode) == 0:
 | 
				
			||||||
 | 
					            raise WebSocketConnectionClosedException("Connection is already closed.")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def send_text(self, text_data: str) -> None:
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Sends UTF-8 encoded text.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        if not self.sock or self.sock.send(text_data, ABNF.OPCODE_TEXT) == 0:
 | 
				
			||||||
 | 
					            raise WebSocketConnectionClosedException("Connection is already closed.")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def send_bytes(self, data: Union[bytes, bytearray]) -> None:
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Sends a sequence of bytes.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        if not self.sock or self.sock.send(data, ABNF.OPCODE_BINARY) == 0:
 | 
				
			||||||
 | 
					            raise WebSocketConnectionClosedException("Connection is already closed.")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def close(self, **kwargs) -> None:
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Close websocket connection.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        self.keep_running = False
 | 
				
			||||||
 | 
					        if self.sock:
 | 
				
			||||||
 | 
					            self.sock.close(**kwargs)
 | 
				
			||||||
 | 
					            self.sock = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def _start_ping_thread(self) -> None:
 | 
				
			||||||
 | 
					        self.last_ping_tm = self.last_pong_tm = float(0)
 | 
				
			||||||
 | 
					        self.stop_ping = threading.Event()
 | 
				
			||||||
 | 
					        self.ping_thread = threading.Thread(target=self._send_ping)
 | 
				
			||||||
 | 
					        self.ping_thread.daemon = True
 | 
				
			||||||
 | 
					        self.ping_thread.start()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def _stop_ping_thread(self) -> None:
 | 
				
			||||||
 | 
					        if self.stop_ping:
 | 
				
			||||||
 | 
					            self.stop_ping.set()
 | 
				
			||||||
 | 
					        if self.ping_thread and self.ping_thread.is_alive():
 | 
				
			||||||
 | 
					            self.ping_thread.join(3)
 | 
				
			||||||
 | 
					        self.last_ping_tm = self.last_pong_tm = float(0)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def _send_ping(self) -> None:
 | 
				
			||||||
 | 
					        if self.stop_ping.wait(self.ping_interval) or self.keep_running is False:
 | 
				
			||||||
 | 
					            return
 | 
				
			||||||
 | 
					        while not self.stop_ping.wait(self.ping_interval) and self.keep_running is True:
 | 
				
			||||||
 | 
					            if self.sock:
 | 
				
			||||||
 | 
					                self.last_ping_tm = time.time()
 | 
				
			||||||
 | 
					                try:
 | 
				
			||||||
 | 
					                    _logging.debug("Sending ping")
 | 
				
			||||||
 | 
					                    self.sock.ping(self.ping_payload)
 | 
				
			||||||
 | 
					                except Exception as e:
 | 
				
			||||||
 | 
					                    _logging.debug(f"Failed to send ping: {e}")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def run_forever(
 | 
				
			||||||
 | 
					        self,
 | 
				
			||||||
 | 
					        sockopt: tuple = None,
 | 
				
			||||||
 | 
					        sslopt: dict = None,
 | 
				
			||||||
 | 
					        ping_interval: Union[float, int] = 0,
 | 
				
			||||||
 | 
					        ping_timeout: Union[float, int, None] = None,
 | 
				
			||||||
 | 
					        ping_payload: str = "",
 | 
				
			||||||
 | 
					        http_proxy_host: str = None,
 | 
				
			||||||
 | 
					        http_proxy_port: Union[int, str] = None,
 | 
				
			||||||
 | 
					        http_no_proxy: list = None,
 | 
				
			||||||
 | 
					        http_proxy_auth: tuple = None,
 | 
				
			||||||
 | 
					        http_proxy_timeout: Optional[float] = None,
 | 
				
			||||||
 | 
					        skip_utf8_validation: bool = False,
 | 
				
			||||||
 | 
					        host: str = None,
 | 
				
			||||||
 | 
					        origin: str = None,
 | 
				
			||||||
 | 
					        dispatcher=None,
 | 
				
			||||||
 | 
					        suppress_origin: bool = False,
 | 
				
			||||||
 | 
					        proxy_type: str = None,
 | 
				
			||||||
 | 
					        reconnect: int = None,
 | 
				
			||||||
 | 
					    ) -> bool:
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Run event loop for WebSocket framework.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        This loop is an infinite loop and is alive while websocket is available.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Parameters
 | 
				
			||||||
 | 
					        ----------
 | 
				
			||||||
 | 
					        sockopt: tuple
 | 
				
			||||||
 | 
					            Values for socket.setsockopt.
 | 
				
			||||||
 | 
					            sockopt must be tuple
 | 
				
			||||||
 | 
					            and each element is argument of sock.setsockopt.
 | 
				
			||||||
 | 
					        sslopt: dict
 | 
				
			||||||
 | 
					            Optional dict object for ssl socket option.
 | 
				
			||||||
 | 
					        ping_interval: int or float
 | 
				
			||||||
 | 
					            Automatically send "ping" command
 | 
				
			||||||
 | 
					            every specified period (in seconds).
 | 
				
			||||||
 | 
					            If set to 0, no ping is sent periodically.
 | 
				
			||||||
 | 
					        ping_timeout: int or float
 | 
				
			||||||
 | 
					            Timeout (in seconds) if the pong message is not received.
 | 
				
			||||||
 | 
					        ping_payload: str
 | 
				
			||||||
 | 
					            Payload message to send with each ping.
 | 
				
			||||||
 | 
					        http_proxy_host: str
 | 
				
			||||||
 | 
					            HTTP proxy host name.
 | 
				
			||||||
 | 
					        http_proxy_port: int or str
 | 
				
			||||||
 | 
					            HTTP proxy port. If not set, set to 80.
 | 
				
			||||||
 | 
					        http_no_proxy: list
 | 
				
			||||||
 | 
					            Whitelisted host names that don't use the proxy.
 | 
				
			||||||
 | 
					        http_proxy_timeout: int or float
 | 
				
			||||||
 | 
					            HTTP proxy timeout, default is 60 sec as per python-socks.
 | 
				
			||||||
 | 
					        http_proxy_auth: tuple
 | 
				
			||||||
 | 
					            HTTP proxy auth information. tuple of username and password. Default is None.
 | 
				
			||||||
 | 
					        skip_utf8_validation: bool
 | 
				
			||||||
 | 
					            skip utf8 validation.
 | 
				
			||||||
 | 
					        host: str
 | 
				
			||||||
 | 
					            update host header.
 | 
				
			||||||
 | 
					        origin: str
 | 
				
			||||||
 | 
					            update origin header.
 | 
				
			||||||
 | 
					        dispatcher: Dispatcher object
 | 
				
			||||||
 | 
					            customize reading data from socket.
 | 
				
			||||||
 | 
					        suppress_origin: bool
 | 
				
			||||||
 | 
					            suppress outputting origin header.
 | 
				
			||||||
 | 
					        proxy_type: str
 | 
				
			||||||
 | 
					            type of proxy from: http, socks4, socks4a, socks5, socks5h
 | 
				
			||||||
 | 
					        reconnect: int
 | 
				
			||||||
 | 
					            delay interval when reconnecting
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Returns
 | 
				
			||||||
 | 
					        -------
 | 
				
			||||||
 | 
					        teardown: bool
 | 
				
			||||||
 | 
					            False if the `WebSocketApp` is closed or caught KeyboardInterrupt,
 | 
				
			||||||
 | 
					            True if any other exception was raised during a loop.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if reconnect is None:
 | 
				
			||||||
 | 
					            reconnect = RECONNECT
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if ping_timeout is not None and ping_timeout <= 0:
 | 
				
			||||||
 | 
					            raise WebSocketException("Ensure ping_timeout > 0")
 | 
				
			||||||
 | 
					        if ping_interval is not None and ping_interval < 0:
 | 
				
			||||||
 | 
					            raise WebSocketException("Ensure ping_interval >= 0")
 | 
				
			||||||
 | 
					        if ping_timeout and ping_interval and ping_interval <= ping_timeout:
 | 
				
			||||||
 | 
					            raise WebSocketException("Ensure ping_interval > ping_timeout")
 | 
				
			||||||
 | 
					        if not sockopt:
 | 
				
			||||||
 | 
					            sockopt = ()
 | 
				
			||||||
 | 
					        if not sslopt:
 | 
				
			||||||
 | 
					            sslopt = {}
 | 
				
			||||||
 | 
					        if self.sock:
 | 
				
			||||||
 | 
					            raise WebSocketException("socket is already opened")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.ping_interval = ping_interval
 | 
				
			||||||
 | 
					        self.ping_timeout = ping_timeout
 | 
				
			||||||
 | 
					        self.ping_payload = ping_payload
 | 
				
			||||||
 | 
					        self.has_done_teardown = False
 | 
				
			||||||
 | 
					        self.keep_running = True
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        def teardown(close_frame: ABNF = None):
 | 
				
			||||||
 | 
					            """
 | 
				
			||||||
 | 
					            Tears down the connection.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            Parameters
 | 
				
			||||||
 | 
					            ----------
 | 
				
			||||||
 | 
					            close_frame: ABNF frame
 | 
				
			||||||
 | 
					                If close_frame is set, the on_close handler is invoked
 | 
				
			||||||
 | 
					                with the statusCode and reason from the provided frame.
 | 
				
			||||||
 | 
					            """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            # teardown() is called in many code paths to ensure resources are cleaned up and on_close is fired.
 | 
				
			||||||
 | 
					            # To ensure the work is only done once, we use this bool and lock.
 | 
				
			||||||
 | 
					            with self.has_done_teardown_lock:
 | 
				
			||||||
 | 
					                if self.has_done_teardown:
 | 
				
			||||||
 | 
					                    return
 | 
				
			||||||
 | 
					                self.has_done_teardown = True
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            self._stop_ping_thread()
 | 
				
			||||||
 | 
					            self.keep_running = False
 | 
				
			||||||
 | 
					            if self.sock:
 | 
				
			||||||
 | 
					                self.sock.close()
 | 
				
			||||||
 | 
					            close_status_code, close_reason = self._get_close_args(
 | 
				
			||||||
 | 
					                close_frame if close_frame else None
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					            self.sock = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            # Finally call the callback AFTER all teardown is complete
 | 
				
			||||||
 | 
					            self._callback(self.on_close, close_status_code, close_reason)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        def setSock(reconnecting: bool = False) -> None:
 | 
				
			||||||
 | 
					            if reconnecting and self.sock:
 | 
				
			||||||
 | 
					                self.sock.shutdown()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            self.sock = WebSocket(
 | 
				
			||||||
 | 
					                self.get_mask_key,
 | 
				
			||||||
 | 
					                sockopt=sockopt,
 | 
				
			||||||
 | 
					                sslopt=sslopt,
 | 
				
			||||||
 | 
					                fire_cont_frame=self.on_cont_message is not None,
 | 
				
			||||||
 | 
					                skip_utf8_validation=skip_utf8_validation,
 | 
				
			||||||
 | 
					                enable_multithread=True,
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            self.sock.settimeout(getdefaulttimeout())
 | 
				
			||||||
 | 
					            try:
 | 
				
			||||||
 | 
					                header = self.header() if callable(self.header) else self.header
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                self.sock.connect(
 | 
				
			||||||
 | 
					                    self.url,
 | 
				
			||||||
 | 
					                    header=header,
 | 
				
			||||||
 | 
					                    cookie=self.cookie,
 | 
				
			||||||
 | 
					                    http_proxy_host=http_proxy_host,
 | 
				
			||||||
 | 
					                    http_proxy_port=http_proxy_port,
 | 
				
			||||||
 | 
					                    http_no_proxy=http_no_proxy,
 | 
				
			||||||
 | 
					                    http_proxy_auth=http_proxy_auth,
 | 
				
			||||||
 | 
					                    http_proxy_timeout=http_proxy_timeout,
 | 
				
			||||||
 | 
					                    subprotocols=self.subprotocols,
 | 
				
			||||||
 | 
					                    host=host,
 | 
				
			||||||
 | 
					                    origin=origin,
 | 
				
			||||||
 | 
					                    suppress_origin=suppress_origin,
 | 
				
			||||||
 | 
					                    proxy_type=proxy_type,
 | 
				
			||||||
 | 
					                    socket=self.prepared_socket,
 | 
				
			||||||
 | 
					                )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                _logging.info("Websocket connected")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                if self.ping_interval:
 | 
				
			||||||
 | 
					                    self._start_ping_thread()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                if reconnecting and self.on_reconnect:
 | 
				
			||||||
 | 
					                    self._callback(self.on_reconnect)
 | 
				
			||||||
 | 
					                else:
 | 
				
			||||||
 | 
					                    self._callback(self.on_open)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                dispatcher.read(self.sock.sock, read, check)
 | 
				
			||||||
 | 
					            except (
 | 
				
			||||||
 | 
					                WebSocketConnectionClosedException,
 | 
				
			||||||
 | 
					                ConnectionRefusedError,
 | 
				
			||||||
 | 
					                KeyboardInterrupt,
 | 
				
			||||||
 | 
					                SystemExit,
 | 
				
			||||||
 | 
					                Exception,
 | 
				
			||||||
 | 
					            ) as e:
 | 
				
			||||||
 | 
					                handleDisconnect(e, reconnecting)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        def read() -> bool:
 | 
				
			||||||
 | 
					            if not self.keep_running:
 | 
				
			||||||
 | 
					                return teardown()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            try:
 | 
				
			||||||
 | 
					                op_code, frame = self.sock.recv_data_frame(True)
 | 
				
			||||||
 | 
					            except (
 | 
				
			||||||
 | 
					                WebSocketConnectionClosedException,
 | 
				
			||||||
 | 
					                KeyboardInterrupt,
 | 
				
			||||||
 | 
					                SSLEOFError,
 | 
				
			||||||
 | 
					            ) as e:
 | 
				
			||||||
 | 
					                if custom_dispatcher:
 | 
				
			||||||
 | 
					                    return handleDisconnect(e, bool(reconnect))
 | 
				
			||||||
 | 
					                else:
 | 
				
			||||||
 | 
					                    raise e
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            if op_code == ABNF.OPCODE_CLOSE:
 | 
				
			||||||
 | 
					                return teardown(frame)
 | 
				
			||||||
 | 
					            elif op_code == ABNF.OPCODE_PING:
 | 
				
			||||||
 | 
					                self._callback(self.on_ping, frame.data)
 | 
				
			||||||
 | 
					            elif op_code == ABNF.OPCODE_PONG:
 | 
				
			||||||
 | 
					                self.last_pong_tm = time.time()
 | 
				
			||||||
 | 
					                self._callback(self.on_pong, frame.data)
 | 
				
			||||||
 | 
					            elif op_code == ABNF.OPCODE_CONT and self.on_cont_message:
 | 
				
			||||||
 | 
					                self._callback(self.on_data, frame.data, frame.opcode, frame.fin)
 | 
				
			||||||
 | 
					                self._callback(self.on_cont_message, frame.data, frame.fin)
 | 
				
			||||||
 | 
					            else:
 | 
				
			||||||
 | 
					                data = frame.data
 | 
				
			||||||
 | 
					                if op_code == ABNF.OPCODE_TEXT and not skip_utf8_validation:
 | 
				
			||||||
 | 
					                    data = data.decode("utf-8")
 | 
				
			||||||
 | 
					                self._callback(self.on_data, data, frame.opcode, True)
 | 
				
			||||||
 | 
					                self._callback(self.on_message, data)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            return True
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        def check() -> bool:
 | 
				
			||||||
 | 
					            if self.ping_timeout:
 | 
				
			||||||
 | 
					                has_timeout_expired = (
 | 
				
			||||||
 | 
					                    time.time() - self.last_ping_tm > self.ping_timeout
 | 
				
			||||||
 | 
					                )
 | 
				
			||||||
 | 
					                has_pong_not_arrived_after_last_ping = (
 | 
				
			||||||
 | 
					                    self.last_pong_tm - self.last_ping_tm < 0
 | 
				
			||||||
 | 
					                )
 | 
				
			||||||
 | 
					                has_pong_arrived_too_late = (
 | 
				
			||||||
 | 
					                    self.last_pong_tm - self.last_ping_tm > self.ping_timeout
 | 
				
			||||||
 | 
					                )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                if (
 | 
				
			||||||
 | 
					                    self.last_ping_tm
 | 
				
			||||||
 | 
					                    and has_timeout_expired
 | 
				
			||||||
 | 
					                    and (
 | 
				
			||||||
 | 
					                        has_pong_not_arrived_after_last_ping
 | 
				
			||||||
 | 
					                        or has_pong_arrived_too_late
 | 
				
			||||||
 | 
					                    )
 | 
				
			||||||
 | 
					                ):
 | 
				
			||||||
 | 
					                    raise WebSocketTimeoutException("ping/pong timed out")
 | 
				
			||||||
 | 
					            return True
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        def handleDisconnect(
 | 
				
			||||||
 | 
					            e: Union[
 | 
				
			||||||
 | 
					                WebSocketConnectionClosedException,
 | 
				
			||||||
 | 
					                ConnectionRefusedError,
 | 
				
			||||||
 | 
					                KeyboardInterrupt,
 | 
				
			||||||
 | 
					                SystemExit,
 | 
				
			||||||
 | 
					                Exception,
 | 
				
			||||||
 | 
					            ],
 | 
				
			||||||
 | 
					            reconnecting: bool = False,
 | 
				
			||||||
 | 
					        ) -> bool:
 | 
				
			||||||
 | 
					            self.has_errored = True
 | 
				
			||||||
 | 
					            self._stop_ping_thread()
 | 
				
			||||||
 | 
					            if not reconnecting:
 | 
				
			||||||
 | 
					                self._callback(self.on_error, e)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            if isinstance(e, (KeyboardInterrupt, SystemExit)):
 | 
				
			||||||
 | 
					                teardown()
 | 
				
			||||||
 | 
					                # Propagate further
 | 
				
			||||||
 | 
					                raise
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            if reconnect:
 | 
				
			||||||
 | 
					                _logging.info(f"{e} - reconnect")
 | 
				
			||||||
 | 
					                if custom_dispatcher:
 | 
				
			||||||
 | 
					                    _logging.debug(
 | 
				
			||||||
 | 
					                        f"Calling custom dispatcher reconnect [{len(inspect.stack())} frames in stack]"
 | 
				
			||||||
 | 
					                    )
 | 
				
			||||||
 | 
					                    dispatcher.reconnect(reconnect, setSock)
 | 
				
			||||||
 | 
					            else:
 | 
				
			||||||
 | 
					                _logging.error(f"{e} - goodbye")
 | 
				
			||||||
 | 
					                teardown()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        custom_dispatcher = bool(dispatcher)
 | 
				
			||||||
 | 
					        dispatcher = self.create_dispatcher(
 | 
				
			||||||
 | 
					            ping_timeout, dispatcher, parse_url(self.url)[3]
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        try:
 | 
				
			||||||
 | 
					            setSock()
 | 
				
			||||||
 | 
					            if not custom_dispatcher and reconnect:
 | 
				
			||||||
 | 
					                while self.keep_running:
 | 
				
			||||||
 | 
					                    _logging.debug(
 | 
				
			||||||
 | 
					                        f"Calling dispatcher reconnect [{len(inspect.stack())} frames in stack]"
 | 
				
			||||||
 | 
					                    )
 | 
				
			||||||
 | 
					                    dispatcher.reconnect(reconnect, setSock)
 | 
				
			||||||
 | 
					        except (KeyboardInterrupt, Exception) as e:
 | 
				
			||||||
 | 
					            _logging.info(f"tearing down on exception {e}")
 | 
				
			||||||
 | 
					            teardown()
 | 
				
			||||||
 | 
					        finally:
 | 
				
			||||||
 | 
					            if not custom_dispatcher:
 | 
				
			||||||
 | 
					                # Ensure teardown was called before returning from run_forever
 | 
				
			||||||
 | 
					                teardown()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return self.has_errored
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def create_dispatcher(
 | 
				
			||||||
 | 
					        self,
 | 
				
			||||||
 | 
					        ping_timeout: Union[float, int, None],
 | 
				
			||||||
 | 
					        dispatcher: Optional[DispatcherBase] = None,
 | 
				
			||||||
 | 
					        is_ssl: bool = False,
 | 
				
			||||||
 | 
					    ) -> Union[Dispatcher, SSLDispatcher, WrappedDispatcher]:
 | 
				
			||||||
 | 
					        if dispatcher:  # If custom dispatcher is set, use WrappedDispatcher
 | 
				
			||||||
 | 
					            return WrappedDispatcher(self, ping_timeout, dispatcher)
 | 
				
			||||||
 | 
					        timeout = ping_timeout or 10
 | 
				
			||||||
 | 
					        if is_ssl:
 | 
				
			||||||
 | 
					            return SSLDispatcher(self, timeout)
 | 
				
			||||||
 | 
					        return Dispatcher(self, timeout)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def _get_close_args(self, close_frame: ABNF) -> list:
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        _get_close_args extracts the close code and reason from the close body
 | 
				
			||||||
 | 
					        if it exists (RFC6455 says WebSocket Connection Close Code is optional)
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        # Need to catch the case where close_frame is None
 | 
				
			||||||
 | 
					        # Otherwise the following if statement causes an error
 | 
				
			||||||
 | 
					        if not self.on_close or not close_frame:
 | 
				
			||||||
 | 
					            return [None, None]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Extract close frame status code
 | 
				
			||||||
 | 
					        if close_frame.data and len(close_frame.data) >= 2:
 | 
				
			||||||
 | 
					            close_status_code = 256 * int(close_frame.data[0]) + int(
 | 
				
			||||||
 | 
					                close_frame.data[1]
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					            reason = close_frame.data[2:]
 | 
				
			||||||
 | 
					            if isinstance(reason, bytes):
 | 
				
			||||||
 | 
					                reason = reason.decode("utf-8")
 | 
				
			||||||
 | 
					            return [close_status_code, reason]
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            # Most likely reached this because len(close_frame_data.data) < 2
 | 
				
			||||||
 | 
					            return [None, None]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def _callback(self, callback, *args) -> None:
 | 
				
			||||||
 | 
					        if callback:
 | 
				
			||||||
 | 
					            try:
 | 
				
			||||||
 | 
					                callback(self, *args)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            except Exception as e:
 | 
				
			||||||
 | 
					                _logging.error(f"error from callback {callback}: {e}")
 | 
				
			||||||
 | 
					                if self.on_error:
 | 
				
			||||||
 | 
					                    self.on_error(self, e)
 | 
				
			||||||
							
								
								
									
										75
									
								
								src/libs/websocket/_cookiejar.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										75
									
								
								src/libs/websocket/_cookiejar.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,75 @@
 | 
				
			|||||||
 | 
					import http.cookies
 | 
				
			||||||
 | 
					from typing import Optional
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					"""
 | 
				
			||||||
 | 
					_cookiejar.py
 | 
				
			||||||
 | 
					websocket - WebSocket client library for Python
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Copyright 2024 engn33r
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					You may obtain a copy of the License at
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					limitations under the License.
 | 
				
			||||||
 | 
					"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class SimpleCookieJar:
 | 
				
			||||||
 | 
					    def __init__(self) -> None:
 | 
				
			||||||
 | 
					        self.jar: dict = {}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def add(self, set_cookie: Optional[str]) -> None:
 | 
				
			||||||
 | 
					        if set_cookie:
 | 
				
			||||||
 | 
					            simple_cookie = http.cookies.SimpleCookie(set_cookie)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            for v in simple_cookie.values():
 | 
				
			||||||
 | 
					                if domain := v.get("domain"):
 | 
				
			||||||
 | 
					                    if not domain.startswith("."):
 | 
				
			||||||
 | 
					                        domain = f".{domain}"
 | 
				
			||||||
 | 
					                    cookie = (
 | 
				
			||||||
 | 
					                        self.jar.get(domain)
 | 
				
			||||||
 | 
					                        if self.jar.get(domain)
 | 
				
			||||||
 | 
					                        else http.cookies.SimpleCookie()
 | 
				
			||||||
 | 
					                    )
 | 
				
			||||||
 | 
					                    cookie.update(simple_cookie)
 | 
				
			||||||
 | 
					                    self.jar[domain.lower()] = cookie
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def set(self, set_cookie: str) -> None:
 | 
				
			||||||
 | 
					        if set_cookie:
 | 
				
			||||||
 | 
					            simple_cookie = http.cookies.SimpleCookie(set_cookie)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            for v in simple_cookie.values():
 | 
				
			||||||
 | 
					                if domain := v.get("domain"):
 | 
				
			||||||
 | 
					                    if not domain.startswith("."):
 | 
				
			||||||
 | 
					                        domain = f".{domain}"
 | 
				
			||||||
 | 
					                    self.jar[domain.lower()] = simple_cookie
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def get(self, host: str) -> str:
 | 
				
			||||||
 | 
					        if not host:
 | 
				
			||||||
 | 
					            return ""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        cookies = []
 | 
				
			||||||
 | 
					        for domain, _ in self.jar.items():
 | 
				
			||||||
 | 
					            host = host.lower()
 | 
				
			||||||
 | 
					            if host.endswith(domain) or host == domain[1:]:
 | 
				
			||||||
 | 
					                cookies.append(self.jar.get(domain))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return "; ".join(
 | 
				
			||||||
 | 
					            filter(
 | 
				
			||||||
 | 
					                None,
 | 
				
			||||||
 | 
					                sorted(
 | 
				
			||||||
 | 
					                    [
 | 
				
			||||||
 | 
					                        f"{k}={v.value}"
 | 
				
			||||||
 | 
					                        for cookie in filter(None, cookies)
 | 
				
			||||||
 | 
					                        for k, v in cookie.items()
 | 
				
			||||||
 | 
					                    ]
 | 
				
			||||||
 | 
					                ),
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
							
								
								
									
										647
									
								
								src/libs/websocket/_core.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										647
									
								
								src/libs/websocket/_core.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,647 @@
 | 
				
			|||||||
 | 
					import socket
 | 
				
			||||||
 | 
					import struct
 | 
				
			||||||
 | 
					import threading
 | 
				
			||||||
 | 
					import time
 | 
				
			||||||
 | 
					from typing import Optional, Union
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# websocket modules
 | 
				
			||||||
 | 
					from ._abnf import ABNF, STATUS_NORMAL, continuous_frame, frame_buffer
 | 
				
			||||||
 | 
					from ._exceptions import WebSocketProtocolException, WebSocketConnectionClosedException
 | 
				
			||||||
 | 
					from ._handshake import SUPPORTED_REDIRECT_STATUSES, handshake
 | 
				
			||||||
 | 
					from ._http import connect, proxy_info
 | 
				
			||||||
 | 
					from ._logging import debug, error, trace, isEnabledForError, isEnabledForTrace
 | 
				
			||||||
 | 
					from ._socket import getdefaulttimeout, recv, send, sock_opt
 | 
				
			||||||
 | 
					from ._ssl_compat import ssl
 | 
				
			||||||
 | 
					from ._utils import NoLock
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					"""
 | 
				
			||||||
 | 
					_core.py
 | 
				
			||||||
 | 
					websocket - WebSocket client library for Python
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Copyright 2024 engn33r
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					You may obtain a copy of the License at
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					limitations under the License.
 | 
				
			||||||
 | 
					"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					__all__ = ["WebSocket", "create_connection"]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class WebSocket:
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    Low level WebSocket interface.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    This class is based on the WebSocket protocol `draft-hixie-thewebsocketprotocol-76 <http://tools.ietf.org/html/draft-hixie-thewebsocketprotocol-76>`_
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    We can connect to the websocket server and send/receive data.
 | 
				
			||||||
 | 
					    The following example is an echo client.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    >>> import websocket
 | 
				
			||||||
 | 
					    >>> ws = websocket.WebSocket()
 | 
				
			||||||
 | 
					    >>> ws.connect("ws://echo.websocket.events")
 | 
				
			||||||
 | 
					    >>> ws.recv()
 | 
				
			||||||
 | 
					    'echo.websocket.events sponsored by Lob.com'
 | 
				
			||||||
 | 
					    >>> ws.send("Hello, Server")
 | 
				
			||||||
 | 
					    19
 | 
				
			||||||
 | 
					    >>> ws.recv()
 | 
				
			||||||
 | 
					    'Hello, Server'
 | 
				
			||||||
 | 
					    >>> ws.close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Parameters
 | 
				
			||||||
 | 
					    ----------
 | 
				
			||||||
 | 
					    get_mask_key: func
 | 
				
			||||||
 | 
					        A callable function to get new mask keys, see the
 | 
				
			||||||
 | 
					        WebSocket.set_mask_key's docstring for more information.
 | 
				
			||||||
 | 
					    sockopt: tuple
 | 
				
			||||||
 | 
					        Values for socket.setsockopt.
 | 
				
			||||||
 | 
					        sockopt must be tuple and each element is argument of sock.setsockopt.
 | 
				
			||||||
 | 
					    sslopt: dict
 | 
				
			||||||
 | 
					        Optional dict object for ssl socket options. See FAQ for details.
 | 
				
			||||||
 | 
					    fire_cont_frame: bool
 | 
				
			||||||
 | 
					        Fire recv event for each cont frame. Default is False.
 | 
				
			||||||
 | 
					    enable_multithread: bool
 | 
				
			||||||
 | 
					        If set to True, lock send method.
 | 
				
			||||||
 | 
					    skip_utf8_validation: bool
 | 
				
			||||||
 | 
					        Skip utf8 validation.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __init__(
 | 
				
			||||||
 | 
					        self,
 | 
				
			||||||
 | 
					        get_mask_key=None,
 | 
				
			||||||
 | 
					        sockopt=None,
 | 
				
			||||||
 | 
					        sslopt=None,
 | 
				
			||||||
 | 
					        fire_cont_frame: bool = False,
 | 
				
			||||||
 | 
					        enable_multithread: bool = True,
 | 
				
			||||||
 | 
					        skip_utf8_validation: bool = False,
 | 
				
			||||||
 | 
					        **_,
 | 
				
			||||||
 | 
					    ):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Initialize WebSocket object.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Parameters
 | 
				
			||||||
 | 
					        ----------
 | 
				
			||||||
 | 
					        sslopt: dict
 | 
				
			||||||
 | 
					            Optional dict object for ssl socket options. See FAQ for details.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        self.sock_opt = sock_opt(sockopt, sslopt)
 | 
				
			||||||
 | 
					        self.handshake_response = None
 | 
				
			||||||
 | 
					        self.sock: Optional[socket.socket] = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.connected = False
 | 
				
			||||||
 | 
					        self.get_mask_key = get_mask_key
 | 
				
			||||||
 | 
					        # These buffer over the build-up of a single frame.
 | 
				
			||||||
 | 
					        self.frame_buffer = frame_buffer(self._recv, skip_utf8_validation)
 | 
				
			||||||
 | 
					        self.cont_frame = continuous_frame(fire_cont_frame, skip_utf8_validation)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if enable_multithread:
 | 
				
			||||||
 | 
					            self.lock = threading.Lock()
 | 
				
			||||||
 | 
					            self.readlock = threading.Lock()
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            self.lock = NoLock()
 | 
				
			||||||
 | 
					            self.readlock = NoLock()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __iter__(self):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Allow iteration over websocket, implying sequential `recv` executions.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        while True:
 | 
				
			||||||
 | 
					            yield self.recv()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __next__(self):
 | 
				
			||||||
 | 
					        return self.recv()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def next(self):
 | 
				
			||||||
 | 
					        return self.__next__()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def fileno(self):
 | 
				
			||||||
 | 
					        return self.sock.fileno()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def set_mask_key(self, func):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Set function to create mask key. You can customize mask key generator.
 | 
				
			||||||
 | 
					        Mainly, this is for testing purpose.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Parameters
 | 
				
			||||||
 | 
					        ----------
 | 
				
			||||||
 | 
					        func: func
 | 
				
			||||||
 | 
					            callable object. the func takes 1 argument as integer.
 | 
				
			||||||
 | 
					            The argument means length of mask key.
 | 
				
			||||||
 | 
					            This func must return string(byte array),
 | 
				
			||||||
 | 
					            which length is argument specified.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        self.get_mask_key = func
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def gettimeout(self) -> Union[float, int, None]:
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Get the websocket timeout (in seconds) as an int or float
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Returns
 | 
				
			||||||
 | 
					        ----------
 | 
				
			||||||
 | 
					        timeout: int or float
 | 
				
			||||||
 | 
					             returns timeout value (in seconds). This value could be either float/integer.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        return self.sock_opt.timeout
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def settimeout(self, timeout: Union[float, int, None]):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Set the timeout to the websocket.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Parameters
 | 
				
			||||||
 | 
					        ----------
 | 
				
			||||||
 | 
					        timeout: int or float
 | 
				
			||||||
 | 
					            timeout time (in seconds). This value could be either float/integer.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        self.sock_opt.timeout = timeout
 | 
				
			||||||
 | 
					        if self.sock:
 | 
				
			||||||
 | 
					            self.sock.settimeout(timeout)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    timeout = property(gettimeout, settimeout)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def getsubprotocol(self):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Get subprotocol
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        if self.handshake_response:
 | 
				
			||||||
 | 
					            return self.handshake_response.subprotocol
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            return None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    subprotocol = property(getsubprotocol)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def getstatus(self):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Get handshake status
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        if self.handshake_response:
 | 
				
			||||||
 | 
					            return self.handshake_response.status
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            return None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    status = property(getstatus)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def getheaders(self):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Get handshake response header
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        if self.handshake_response:
 | 
				
			||||||
 | 
					            return self.handshake_response.headers
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            return None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def is_ssl(self):
 | 
				
			||||||
 | 
					        try:
 | 
				
			||||||
 | 
					            return isinstance(self.sock, ssl.SSLSocket)
 | 
				
			||||||
 | 
					        except:
 | 
				
			||||||
 | 
					            return False
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    headers = property(getheaders)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def connect(self, url, **options):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Connect to url. url is websocket url scheme.
 | 
				
			||||||
 | 
					        ie. ws://host:port/resource
 | 
				
			||||||
 | 
					        You can customize using 'options'.
 | 
				
			||||||
 | 
					        If you set "header" list object, you can set your own custom header.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        >>> ws = WebSocket()
 | 
				
			||||||
 | 
					        >>> ws.connect("ws://echo.websocket.events",
 | 
				
			||||||
 | 
					                ...     header=["User-Agent: MyProgram",
 | 
				
			||||||
 | 
					                ...             "x-custom: header"])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Parameters
 | 
				
			||||||
 | 
					        ----------
 | 
				
			||||||
 | 
					        header: list or dict
 | 
				
			||||||
 | 
					            Custom http header list or dict.
 | 
				
			||||||
 | 
					        cookie: str
 | 
				
			||||||
 | 
					            Cookie value.
 | 
				
			||||||
 | 
					        origin: str
 | 
				
			||||||
 | 
					            Custom origin url.
 | 
				
			||||||
 | 
					        connection: str
 | 
				
			||||||
 | 
					            Custom connection header value.
 | 
				
			||||||
 | 
					            Default value "Upgrade" set in _handshake.py
 | 
				
			||||||
 | 
					        suppress_origin: bool
 | 
				
			||||||
 | 
					            Suppress outputting origin header.
 | 
				
			||||||
 | 
					        host: str
 | 
				
			||||||
 | 
					            Custom host header string.
 | 
				
			||||||
 | 
					        timeout: int or float
 | 
				
			||||||
 | 
					            Socket timeout time. This value is an integer or float.
 | 
				
			||||||
 | 
					            If you set None for this value, it means "use default_timeout value"
 | 
				
			||||||
 | 
					        http_proxy_host: str
 | 
				
			||||||
 | 
					            HTTP proxy host name.
 | 
				
			||||||
 | 
					        http_proxy_port: str or int
 | 
				
			||||||
 | 
					            HTTP proxy port. Default is 80.
 | 
				
			||||||
 | 
					        http_no_proxy: list
 | 
				
			||||||
 | 
					            Whitelisted host names that don't use the proxy.
 | 
				
			||||||
 | 
					        http_proxy_auth: tuple
 | 
				
			||||||
 | 
					            HTTP proxy auth information. Tuple of username and password. Default is None.
 | 
				
			||||||
 | 
					        http_proxy_timeout: int or float
 | 
				
			||||||
 | 
					            HTTP proxy timeout, default is 60 sec as per python-socks.
 | 
				
			||||||
 | 
					        redirect_limit: int
 | 
				
			||||||
 | 
					            Number of redirects to follow.
 | 
				
			||||||
 | 
					        subprotocols: list
 | 
				
			||||||
 | 
					            List of available subprotocols. Default is None.
 | 
				
			||||||
 | 
					        socket: socket
 | 
				
			||||||
 | 
					            Pre-initialized stream socket.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        self.sock_opt.timeout = options.get("timeout", self.sock_opt.timeout)
 | 
				
			||||||
 | 
					        self.sock, addrs = connect(
 | 
				
			||||||
 | 
					            url, self.sock_opt, proxy_info(**options), options.pop("socket", None)
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        try:
 | 
				
			||||||
 | 
					            self.handshake_response = handshake(self.sock, url, *addrs, **options)
 | 
				
			||||||
 | 
					            for _ in range(options.pop("redirect_limit", 3)):
 | 
				
			||||||
 | 
					                if self.handshake_response.status in SUPPORTED_REDIRECT_STATUSES:
 | 
				
			||||||
 | 
					                    url = self.handshake_response.headers["location"]
 | 
				
			||||||
 | 
					                    self.sock.close()
 | 
				
			||||||
 | 
					                    self.sock, addrs = connect(
 | 
				
			||||||
 | 
					                        url,
 | 
				
			||||||
 | 
					                        self.sock_opt,
 | 
				
			||||||
 | 
					                        proxy_info(**options),
 | 
				
			||||||
 | 
					                        options.pop("socket", None),
 | 
				
			||||||
 | 
					                    )
 | 
				
			||||||
 | 
					                    self.handshake_response = handshake(
 | 
				
			||||||
 | 
					                        self.sock, url, *addrs, **options
 | 
				
			||||||
 | 
					                    )
 | 
				
			||||||
 | 
					            self.connected = True
 | 
				
			||||||
 | 
					        except:
 | 
				
			||||||
 | 
					            if self.sock:
 | 
				
			||||||
 | 
					                self.sock.close()
 | 
				
			||||||
 | 
					                self.sock = None
 | 
				
			||||||
 | 
					            raise
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def send(self, payload: Union[bytes, str], opcode: int = ABNF.OPCODE_TEXT) -> int:
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Send the data as string.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Parameters
 | 
				
			||||||
 | 
					        ----------
 | 
				
			||||||
 | 
					        payload: str
 | 
				
			||||||
 | 
					            Payload must be utf-8 string or unicode,
 | 
				
			||||||
 | 
					            If the opcode is OPCODE_TEXT.
 | 
				
			||||||
 | 
					            Otherwise, it must be string(byte array).
 | 
				
			||||||
 | 
					        opcode: int
 | 
				
			||||||
 | 
					            Operation code (opcode) to send.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        frame = ABNF.create_frame(payload, opcode)
 | 
				
			||||||
 | 
					        return self.send_frame(frame)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def send_text(self, text_data: str) -> int:
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Sends UTF-8 encoded text.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        return self.send(text_data, ABNF.OPCODE_TEXT)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def send_bytes(self, data: Union[bytes, bytearray]) -> int:
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Sends a sequence of bytes.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        return self.send(data, ABNF.OPCODE_BINARY)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def send_frame(self, frame) -> int:
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Send the data frame.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        >>> ws = create_connection("ws://echo.websocket.events")
 | 
				
			||||||
 | 
					        >>> frame = ABNF.create_frame("Hello", ABNF.OPCODE_TEXT)
 | 
				
			||||||
 | 
					        >>> ws.send_frame(frame)
 | 
				
			||||||
 | 
					        >>> cont_frame = ABNF.create_frame("My name is ", ABNF.OPCODE_CONT, 0)
 | 
				
			||||||
 | 
					        >>> ws.send_frame(frame)
 | 
				
			||||||
 | 
					        >>> cont_frame = ABNF.create_frame("Foo Bar", ABNF.OPCODE_CONT, 1)
 | 
				
			||||||
 | 
					        >>> ws.send_frame(frame)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Parameters
 | 
				
			||||||
 | 
					        ----------
 | 
				
			||||||
 | 
					        frame: ABNF frame
 | 
				
			||||||
 | 
					            frame data created by ABNF.create_frame
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        if self.get_mask_key:
 | 
				
			||||||
 | 
					            frame.get_mask_key = self.get_mask_key
 | 
				
			||||||
 | 
					        data = frame.format()
 | 
				
			||||||
 | 
					        length = len(data)
 | 
				
			||||||
 | 
					        if isEnabledForTrace():
 | 
				
			||||||
 | 
					            trace(f"++Sent raw: {repr(data)}")
 | 
				
			||||||
 | 
					            trace(f"++Sent decoded: {frame.__str__()}")
 | 
				
			||||||
 | 
					        with self.lock:
 | 
				
			||||||
 | 
					            while data:
 | 
				
			||||||
 | 
					                l = self._send(data)
 | 
				
			||||||
 | 
					                data = data[l:]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return length
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def send_binary(self, payload: bytes) -> int:
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Send a binary message (OPCODE_BINARY).
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Parameters
 | 
				
			||||||
 | 
					        ----------
 | 
				
			||||||
 | 
					        payload: bytes
 | 
				
			||||||
 | 
					            payload of message to send.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        return self.send(payload, ABNF.OPCODE_BINARY)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def ping(self, payload: Union[str, bytes] = ""):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Send ping data.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Parameters
 | 
				
			||||||
 | 
					        ----------
 | 
				
			||||||
 | 
					        payload: str
 | 
				
			||||||
 | 
					            data payload to send server.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        if isinstance(payload, str):
 | 
				
			||||||
 | 
					            payload = payload.encode("utf-8")
 | 
				
			||||||
 | 
					        self.send(payload, ABNF.OPCODE_PING)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def pong(self, payload: Union[str, bytes] = ""):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Send pong data.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Parameters
 | 
				
			||||||
 | 
					        ----------
 | 
				
			||||||
 | 
					        payload: str
 | 
				
			||||||
 | 
					            data payload to send server.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        if isinstance(payload, str):
 | 
				
			||||||
 | 
					            payload = payload.encode("utf-8")
 | 
				
			||||||
 | 
					        self.send(payload, ABNF.OPCODE_PONG)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def recv(self) -> Union[str, bytes]:
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Receive string data(byte array) from the server.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Returns
 | 
				
			||||||
 | 
					        ----------
 | 
				
			||||||
 | 
					        data: string (byte array) value.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        with self.readlock:
 | 
				
			||||||
 | 
					            opcode, data = self.recv_data()
 | 
				
			||||||
 | 
					        if opcode == ABNF.OPCODE_TEXT:
 | 
				
			||||||
 | 
					            data_received: Union[bytes, str] = data
 | 
				
			||||||
 | 
					            if isinstance(data_received, bytes):
 | 
				
			||||||
 | 
					                return data_received.decode("utf-8")
 | 
				
			||||||
 | 
					            elif isinstance(data_received, str):
 | 
				
			||||||
 | 
					                return data_received
 | 
				
			||||||
 | 
					        elif opcode == ABNF.OPCODE_BINARY:
 | 
				
			||||||
 | 
					            data_binary: bytes = data
 | 
				
			||||||
 | 
					            return data_binary
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            return ""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def recv_data(self, control_frame: bool = False) -> tuple:
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Receive data with operation code.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Parameters
 | 
				
			||||||
 | 
					        ----------
 | 
				
			||||||
 | 
					        control_frame: bool
 | 
				
			||||||
 | 
					            a boolean flag indicating whether to return control frame
 | 
				
			||||||
 | 
					            data, defaults to False
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Returns
 | 
				
			||||||
 | 
					        -------
 | 
				
			||||||
 | 
					        opcode, frame.data: tuple
 | 
				
			||||||
 | 
					            tuple of operation code and string(byte array) value.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        opcode, frame = self.recv_data_frame(control_frame)
 | 
				
			||||||
 | 
					        return opcode, frame.data
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def recv_data_frame(self, control_frame: bool = False) -> tuple:
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Receive data with operation code.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        If a valid ping message is received, a pong response is sent.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Parameters
 | 
				
			||||||
 | 
					        ----------
 | 
				
			||||||
 | 
					        control_frame: bool
 | 
				
			||||||
 | 
					            a boolean flag indicating whether to return control frame
 | 
				
			||||||
 | 
					            data, defaults to False
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Returns
 | 
				
			||||||
 | 
					        -------
 | 
				
			||||||
 | 
					        frame.opcode, frame: tuple
 | 
				
			||||||
 | 
					            tuple of operation code and string(byte array) value.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        while True:
 | 
				
			||||||
 | 
					            frame = self.recv_frame()
 | 
				
			||||||
 | 
					            if isEnabledForTrace():
 | 
				
			||||||
 | 
					                trace(f"++Rcv raw: {repr(frame.format())}")
 | 
				
			||||||
 | 
					                trace(f"++Rcv decoded: {frame.__str__()}")
 | 
				
			||||||
 | 
					            if not frame:
 | 
				
			||||||
 | 
					                # handle error:
 | 
				
			||||||
 | 
					                # 'NoneType' object has no attribute 'opcode'
 | 
				
			||||||
 | 
					                raise WebSocketProtocolException(f"Not a valid frame {frame}")
 | 
				
			||||||
 | 
					            elif frame.opcode in (
 | 
				
			||||||
 | 
					                ABNF.OPCODE_TEXT,
 | 
				
			||||||
 | 
					                ABNF.OPCODE_BINARY,
 | 
				
			||||||
 | 
					                ABNF.OPCODE_CONT,
 | 
				
			||||||
 | 
					            ):
 | 
				
			||||||
 | 
					                self.cont_frame.validate(frame)
 | 
				
			||||||
 | 
					                self.cont_frame.add(frame)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                if self.cont_frame.is_fire(frame):
 | 
				
			||||||
 | 
					                    return self.cont_frame.extract(frame)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            elif frame.opcode == ABNF.OPCODE_CLOSE:
 | 
				
			||||||
 | 
					                self.send_close()
 | 
				
			||||||
 | 
					                return frame.opcode, frame
 | 
				
			||||||
 | 
					            elif frame.opcode == ABNF.OPCODE_PING:
 | 
				
			||||||
 | 
					                if len(frame.data) < 126:
 | 
				
			||||||
 | 
					                    self.pong(frame.data)
 | 
				
			||||||
 | 
					                else:
 | 
				
			||||||
 | 
					                    raise WebSocketProtocolException("Ping message is too long")
 | 
				
			||||||
 | 
					                if control_frame:
 | 
				
			||||||
 | 
					                    return frame.opcode, frame
 | 
				
			||||||
 | 
					            elif frame.opcode == ABNF.OPCODE_PONG:
 | 
				
			||||||
 | 
					                if control_frame:
 | 
				
			||||||
 | 
					                    return frame.opcode, frame
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def recv_frame(self):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Receive data as frame from server.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Returns
 | 
				
			||||||
 | 
					        -------
 | 
				
			||||||
 | 
					        self.frame_buffer.recv_frame(): ABNF frame object
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        return self.frame_buffer.recv_frame()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def send_close(self, status: int = STATUS_NORMAL, reason: bytes = b""):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Send close data to the server.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Parameters
 | 
				
			||||||
 | 
					        ----------
 | 
				
			||||||
 | 
					        status: int
 | 
				
			||||||
 | 
					            Status code to send. See STATUS_XXX.
 | 
				
			||||||
 | 
					        reason: str or bytes
 | 
				
			||||||
 | 
					            The reason to close. This must be string or UTF-8 bytes.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        if status < 0 or status >= ABNF.LENGTH_16:
 | 
				
			||||||
 | 
					            raise ValueError("code is invalid range")
 | 
				
			||||||
 | 
					        self.connected = False
 | 
				
			||||||
 | 
					        self.send(struct.pack("!H", status) + reason, ABNF.OPCODE_CLOSE)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def close(self, status: int = STATUS_NORMAL, reason: bytes = b"", timeout: int = 3):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Close Websocket object
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Parameters
 | 
				
			||||||
 | 
					        ----------
 | 
				
			||||||
 | 
					        status: int
 | 
				
			||||||
 | 
					            Status code to send. See VALID_CLOSE_STATUS in ABNF.
 | 
				
			||||||
 | 
					        reason: bytes
 | 
				
			||||||
 | 
					            The reason to close in UTF-8.
 | 
				
			||||||
 | 
					        timeout: int or float
 | 
				
			||||||
 | 
					            Timeout until receive a close frame.
 | 
				
			||||||
 | 
					            If None, it will wait forever until receive a close frame.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        if not self.connected:
 | 
				
			||||||
 | 
					            return
 | 
				
			||||||
 | 
					        if status < 0 or status >= ABNF.LENGTH_16:
 | 
				
			||||||
 | 
					            raise ValueError("code is invalid range")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        try:
 | 
				
			||||||
 | 
					            self.connected = False
 | 
				
			||||||
 | 
					            self.send(struct.pack("!H", status) + reason, ABNF.OPCODE_CLOSE)
 | 
				
			||||||
 | 
					            sock_timeout = self.sock.gettimeout()
 | 
				
			||||||
 | 
					            self.sock.settimeout(timeout)
 | 
				
			||||||
 | 
					            start_time = time.time()
 | 
				
			||||||
 | 
					            while timeout is None or time.time() - start_time < timeout:
 | 
				
			||||||
 | 
					                try:
 | 
				
			||||||
 | 
					                    frame = self.recv_frame()
 | 
				
			||||||
 | 
					                    if frame.opcode != ABNF.OPCODE_CLOSE:
 | 
				
			||||||
 | 
					                        continue
 | 
				
			||||||
 | 
					                    if isEnabledForError():
 | 
				
			||||||
 | 
					                        recv_status = struct.unpack("!H", frame.data[0:2])[0]
 | 
				
			||||||
 | 
					                        if recv_status >= 3000 and recv_status <= 4999:
 | 
				
			||||||
 | 
					                            debug(f"close status: {repr(recv_status)}")
 | 
				
			||||||
 | 
					                        elif recv_status != STATUS_NORMAL:
 | 
				
			||||||
 | 
					                            error(f"close status: {repr(recv_status)}")
 | 
				
			||||||
 | 
					                    break
 | 
				
			||||||
 | 
					                except:
 | 
				
			||||||
 | 
					                    break
 | 
				
			||||||
 | 
					            self.sock.settimeout(sock_timeout)
 | 
				
			||||||
 | 
					            self.sock.shutdown(socket.SHUT_RDWR)
 | 
				
			||||||
 | 
					        except:
 | 
				
			||||||
 | 
					            pass
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.shutdown()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def abort(self):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Low-level asynchronous abort, wakes up other threads that are waiting in recv_*
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        if self.connected:
 | 
				
			||||||
 | 
					            self.sock.shutdown(socket.SHUT_RDWR)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def shutdown(self):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        close socket, immediately.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        if self.sock:
 | 
				
			||||||
 | 
					            self.sock.close()
 | 
				
			||||||
 | 
					            self.sock = None
 | 
				
			||||||
 | 
					            self.connected = False
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def _send(self, data: Union[str, bytes]):
 | 
				
			||||||
 | 
					        return send(self.sock, data)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def _recv(self, bufsize):
 | 
				
			||||||
 | 
					        try:
 | 
				
			||||||
 | 
					            return recv(self.sock, bufsize)
 | 
				
			||||||
 | 
					        except WebSocketConnectionClosedException:
 | 
				
			||||||
 | 
					            if self.sock:
 | 
				
			||||||
 | 
					                self.sock.close()
 | 
				
			||||||
 | 
					            self.sock = None
 | 
				
			||||||
 | 
					            self.connected = False
 | 
				
			||||||
 | 
					            raise
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def create_connection(url: str, timeout=None, class_=WebSocket, **options):
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    Connect to url and return websocket object.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Connect to url and return the WebSocket object.
 | 
				
			||||||
 | 
					    Passing optional timeout parameter will set the timeout on the socket.
 | 
				
			||||||
 | 
					    If no timeout is supplied,
 | 
				
			||||||
 | 
					    the global default timeout setting returned by getdefaulttimeout() is used.
 | 
				
			||||||
 | 
					    You can customize using 'options'.
 | 
				
			||||||
 | 
					    If you set "header" list object, you can set your own custom header.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    >>> conn = create_connection("ws://echo.websocket.events",
 | 
				
			||||||
 | 
					         ...     header=["User-Agent: MyProgram",
 | 
				
			||||||
 | 
					         ...             "x-custom: header"])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Parameters
 | 
				
			||||||
 | 
					    ----------
 | 
				
			||||||
 | 
					    class_: class
 | 
				
			||||||
 | 
					        class to instantiate when creating the connection. It has to implement
 | 
				
			||||||
 | 
					        settimeout and connect. It's __init__ should be compatible with
 | 
				
			||||||
 | 
					        WebSocket.__init__, i.e. accept all of it's kwargs.
 | 
				
			||||||
 | 
					    header: list or dict
 | 
				
			||||||
 | 
					        custom http header list or dict.
 | 
				
			||||||
 | 
					    cookie: str
 | 
				
			||||||
 | 
					        Cookie value.
 | 
				
			||||||
 | 
					    origin: str
 | 
				
			||||||
 | 
					        custom origin url.
 | 
				
			||||||
 | 
					    suppress_origin: bool
 | 
				
			||||||
 | 
					        suppress outputting origin header.
 | 
				
			||||||
 | 
					    host: str
 | 
				
			||||||
 | 
					        custom host header string.
 | 
				
			||||||
 | 
					    timeout: int or float
 | 
				
			||||||
 | 
					        socket timeout time. This value could be either float/integer.
 | 
				
			||||||
 | 
					        If set to None, it uses the default_timeout value.
 | 
				
			||||||
 | 
					    http_proxy_host: str
 | 
				
			||||||
 | 
					        HTTP proxy host name.
 | 
				
			||||||
 | 
					    http_proxy_port: str or int
 | 
				
			||||||
 | 
					        HTTP proxy port. If not set, set to 80.
 | 
				
			||||||
 | 
					    http_no_proxy: list
 | 
				
			||||||
 | 
					        Whitelisted host names that don't use the proxy.
 | 
				
			||||||
 | 
					    http_proxy_auth: tuple
 | 
				
			||||||
 | 
					        HTTP proxy auth information. tuple of username and password. Default is None.
 | 
				
			||||||
 | 
					    http_proxy_timeout: int or float
 | 
				
			||||||
 | 
					        HTTP proxy timeout, default is 60 sec as per python-socks.
 | 
				
			||||||
 | 
					    enable_multithread: bool
 | 
				
			||||||
 | 
					        Enable lock for multithread.
 | 
				
			||||||
 | 
					    redirect_limit: int
 | 
				
			||||||
 | 
					        Number of redirects to follow.
 | 
				
			||||||
 | 
					    sockopt: tuple
 | 
				
			||||||
 | 
					        Values for socket.setsockopt.
 | 
				
			||||||
 | 
					        sockopt must be a tuple and each element is an argument of sock.setsockopt.
 | 
				
			||||||
 | 
					    sslopt: dict
 | 
				
			||||||
 | 
					        Optional dict object for ssl socket options. See FAQ for details.
 | 
				
			||||||
 | 
					    subprotocols: list
 | 
				
			||||||
 | 
					        List of available subprotocols. Default is None.
 | 
				
			||||||
 | 
					    skip_utf8_validation: bool
 | 
				
			||||||
 | 
					        Skip utf8 validation.
 | 
				
			||||||
 | 
					    socket: socket
 | 
				
			||||||
 | 
					        Pre-initialized stream socket.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    sockopt = options.pop("sockopt", [])
 | 
				
			||||||
 | 
					    sslopt = options.pop("sslopt", {})
 | 
				
			||||||
 | 
					    fire_cont_frame = options.pop("fire_cont_frame", False)
 | 
				
			||||||
 | 
					    enable_multithread = options.pop("enable_multithread", True)
 | 
				
			||||||
 | 
					    skip_utf8_validation = options.pop("skip_utf8_validation", False)
 | 
				
			||||||
 | 
					    websock = class_(
 | 
				
			||||||
 | 
					        sockopt=sockopt,
 | 
				
			||||||
 | 
					        sslopt=sslopt,
 | 
				
			||||||
 | 
					        fire_cont_frame=fire_cont_frame,
 | 
				
			||||||
 | 
					        enable_multithread=enable_multithread,
 | 
				
			||||||
 | 
					        skip_utf8_validation=skip_utf8_validation,
 | 
				
			||||||
 | 
					        **options,
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    websock.settimeout(timeout if timeout is not None else getdefaulttimeout())
 | 
				
			||||||
 | 
					    websock.connect(url, **options)
 | 
				
			||||||
 | 
					    return websock
 | 
				
			||||||
							
								
								
									
										94
									
								
								src/libs/websocket/_exceptions.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										94
									
								
								src/libs/websocket/_exceptions.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,94 @@
 | 
				
			|||||||
 | 
					"""
 | 
				
			||||||
 | 
					_exceptions.py
 | 
				
			||||||
 | 
					websocket - WebSocket client library for Python
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Copyright 2024 engn33r
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					You may obtain a copy of the License at
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					limitations under the License.
 | 
				
			||||||
 | 
					"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class WebSocketException(Exception):
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    WebSocket exception class.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    pass
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class WebSocketProtocolException(WebSocketException):
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    If the WebSocket protocol is invalid, this exception will be raised.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    pass
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class WebSocketPayloadException(WebSocketException):
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    If the WebSocket payload is invalid, this exception will be raised.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    pass
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class WebSocketConnectionClosedException(WebSocketException):
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    If remote host closed the connection or some network error happened,
 | 
				
			||||||
 | 
					    this exception will be raised.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    pass
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class WebSocketTimeoutException(WebSocketException):
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    WebSocketTimeoutException will be raised at socket timeout during read/write data.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    pass
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class WebSocketProxyException(WebSocketException):
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    WebSocketProxyException will be raised when proxy error occurred.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    pass
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class WebSocketBadStatusException(WebSocketException):
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    WebSocketBadStatusException will be raised when we get bad handshake status code.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __init__(
 | 
				
			||||||
 | 
					        self,
 | 
				
			||||||
 | 
					        message: str,
 | 
				
			||||||
 | 
					        status_code: int,
 | 
				
			||||||
 | 
					        status_message=None,
 | 
				
			||||||
 | 
					        resp_headers=None,
 | 
				
			||||||
 | 
					        resp_body=None,
 | 
				
			||||||
 | 
					    ):
 | 
				
			||||||
 | 
					        super().__init__(message)
 | 
				
			||||||
 | 
					        self.status_code = status_code
 | 
				
			||||||
 | 
					        self.resp_headers = resp_headers
 | 
				
			||||||
 | 
					        self.resp_body = resp_body
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class WebSocketAddressException(WebSocketException):
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    If the websocket address info cannot be found, this exception will be raised.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    pass
 | 
				
			||||||
							
								
								
									
										203
									
								
								src/libs/websocket/_handshake.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										203
									
								
								src/libs/websocket/_handshake.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,203 @@
 | 
				
			|||||||
 | 
					"""
 | 
				
			||||||
 | 
					_handshake.py
 | 
				
			||||||
 | 
					websocket - WebSocket client library for Python
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Copyright 2024 engn33r
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					You may obtain a copy of the License at
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					limitations under the License.
 | 
				
			||||||
 | 
					"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import hashlib
 | 
				
			||||||
 | 
					import hmac
 | 
				
			||||||
 | 
					import os
 | 
				
			||||||
 | 
					from base64 import encodebytes as base64encode
 | 
				
			||||||
 | 
					from http import HTTPStatus
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from ._cookiejar import SimpleCookieJar
 | 
				
			||||||
 | 
					from ._exceptions import WebSocketException, WebSocketBadStatusException
 | 
				
			||||||
 | 
					from ._http import read_headers
 | 
				
			||||||
 | 
					from ._logging import dump, error
 | 
				
			||||||
 | 
					from ._socket import send
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					__all__ = ["handshake_response", "handshake", "SUPPORTED_REDIRECT_STATUSES"]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# websocket supported version.
 | 
				
			||||||
 | 
					VERSION = 13
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					SUPPORTED_REDIRECT_STATUSES = (
 | 
				
			||||||
 | 
					    HTTPStatus.MOVED_PERMANENTLY,
 | 
				
			||||||
 | 
					    HTTPStatus.FOUND,
 | 
				
			||||||
 | 
					    HTTPStatus.SEE_OTHER,
 | 
				
			||||||
 | 
					    HTTPStatus.TEMPORARY_REDIRECT,
 | 
				
			||||||
 | 
					    HTTPStatus.PERMANENT_REDIRECT,
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					SUCCESS_STATUSES = SUPPORTED_REDIRECT_STATUSES + (HTTPStatus.SWITCHING_PROTOCOLS,)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					CookieJar = SimpleCookieJar()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class handshake_response:
 | 
				
			||||||
 | 
					    def __init__(self, status: int, headers: dict, subprotocol):
 | 
				
			||||||
 | 
					        self.status = status
 | 
				
			||||||
 | 
					        self.headers = headers
 | 
				
			||||||
 | 
					        self.subprotocol = subprotocol
 | 
				
			||||||
 | 
					        CookieJar.add(headers.get("set-cookie"))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def handshake(
 | 
				
			||||||
 | 
					    sock, url: str, hostname: str, port: int, resource: str, **options
 | 
				
			||||||
 | 
					) -> handshake_response:
 | 
				
			||||||
 | 
					    headers, key = _get_handshake_headers(resource, url, hostname, port, options)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    header_str = "\r\n".join(headers)
 | 
				
			||||||
 | 
					    send(sock, header_str)
 | 
				
			||||||
 | 
					    dump("request header", header_str)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    status, resp = _get_resp_headers(sock)
 | 
				
			||||||
 | 
					    if status in SUPPORTED_REDIRECT_STATUSES:
 | 
				
			||||||
 | 
					        return handshake_response(status, resp, None)
 | 
				
			||||||
 | 
					    success, subproto = _validate(resp, key, options.get("subprotocols"))
 | 
				
			||||||
 | 
					    if not success:
 | 
				
			||||||
 | 
					        raise WebSocketException("Invalid WebSocket Header")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return handshake_response(status, resp, subproto)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def _pack_hostname(hostname: str) -> str:
 | 
				
			||||||
 | 
					    # IPv6 address
 | 
				
			||||||
 | 
					    if ":" in hostname:
 | 
				
			||||||
 | 
					        return f"[{hostname}]"
 | 
				
			||||||
 | 
					    return hostname
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def _get_handshake_headers(
 | 
				
			||||||
 | 
					    resource: str, url: str, host: str, port: int, options: dict
 | 
				
			||||||
 | 
					) -> tuple:
 | 
				
			||||||
 | 
					    headers = [f"GET {resource} HTTP/1.1", "Upgrade: websocket"]
 | 
				
			||||||
 | 
					    if port in [80, 443]:
 | 
				
			||||||
 | 
					        hostport = _pack_hostname(host)
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        hostport = f"{_pack_hostname(host)}:{port}"
 | 
				
			||||||
 | 
					    if options.get("host"):
 | 
				
			||||||
 | 
					        headers.append(f'Host: {options["host"]}')
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        headers.append(f"Host: {hostport}")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # scheme indicates whether http or https is used in Origin
 | 
				
			||||||
 | 
					    # The same approach is used in parse_url of _url.py to set default port
 | 
				
			||||||
 | 
					    scheme, url = url.split(":", 1)
 | 
				
			||||||
 | 
					    if not options.get("suppress_origin"):
 | 
				
			||||||
 | 
					        if "origin" in options and options["origin"] is not None:
 | 
				
			||||||
 | 
					            headers.append(f'Origin: {options["origin"]}')
 | 
				
			||||||
 | 
					        elif scheme == "wss":
 | 
				
			||||||
 | 
					            headers.append(f"Origin: https://{hostport}")
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            headers.append(f"Origin: http://{hostport}")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    key = _create_sec_websocket_key()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Append Sec-WebSocket-Key & Sec-WebSocket-Version if not manually specified
 | 
				
			||||||
 | 
					    if not options.get("header") or "Sec-WebSocket-Key" not in options["header"]:
 | 
				
			||||||
 | 
					        headers.append(f"Sec-WebSocket-Key: {key}")
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        key = options["header"]["Sec-WebSocket-Key"]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if not options.get("header") or "Sec-WebSocket-Version" not in options["header"]:
 | 
				
			||||||
 | 
					        headers.append(f"Sec-WebSocket-Version: {VERSION}")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if not options.get("connection"):
 | 
				
			||||||
 | 
					        headers.append("Connection: Upgrade")
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        headers.append(options["connection"])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if subprotocols := options.get("subprotocols"):
 | 
				
			||||||
 | 
					        headers.append(f'Sec-WebSocket-Protocol: {",".join(subprotocols)}')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if header := options.get("header"):
 | 
				
			||||||
 | 
					        if isinstance(header, dict):
 | 
				
			||||||
 | 
					            header = [": ".join([k, v]) for k, v in header.items() if v is not None]
 | 
				
			||||||
 | 
					        headers.extend(header)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    server_cookie = CookieJar.get(host)
 | 
				
			||||||
 | 
					    client_cookie = options.get("cookie", None)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if cookie := "; ".join(filter(None, [server_cookie, client_cookie])):
 | 
				
			||||||
 | 
					        headers.append(f"Cookie: {cookie}")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    headers.extend(("", ""))
 | 
				
			||||||
 | 
					    return headers, key
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def _get_resp_headers(sock, success_statuses: tuple = SUCCESS_STATUSES) -> tuple:
 | 
				
			||||||
 | 
					    status, resp_headers, status_message = read_headers(sock)
 | 
				
			||||||
 | 
					    if status not in success_statuses:
 | 
				
			||||||
 | 
					        content_len = resp_headers.get("content-length")
 | 
				
			||||||
 | 
					        if content_len:
 | 
				
			||||||
 | 
					            response_body = sock.recv(
 | 
				
			||||||
 | 
					                int(content_len)
 | 
				
			||||||
 | 
					            )  # read the body of the HTTP error message response and include it in the exception
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            response_body = None
 | 
				
			||||||
 | 
					        raise WebSocketBadStatusException(
 | 
				
			||||||
 | 
					            f"Handshake status {status} {status_message} -+-+- {resp_headers} -+-+- {response_body}",
 | 
				
			||||||
 | 
					            status,
 | 
				
			||||||
 | 
					            status_message,
 | 
				
			||||||
 | 
					            resp_headers,
 | 
				
			||||||
 | 
					            response_body,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					    return status, resp_headers
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					_HEADERS_TO_CHECK = {
 | 
				
			||||||
 | 
					    "upgrade": "websocket",
 | 
				
			||||||
 | 
					    "connection": "upgrade",
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def _validate(headers, key: str, subprotocols) -> tuple:
 | 
				
			||||||
 | 
					    subproto = None
 | 
				
			||||||
 | 
					    for k, v in _HEADERS_TO_CHECK.items():
 | 
				
			||||||
 | 
					        r = headers.get(k, None)
 | 
				
			||||||
 | 
					        if not r:
 | 
				
			||||||
 | 
					            return False, None
 | 
				
			||||||
 | 
					        r = [x.strip().lower() for x in r.split(",")]
 | 
				
			||||||
 | 
					        if v not in r:
 | 
				
			||||||
 | 
					            return False, None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if subprotocols:
 | 
				
			||||||
 | 
					        subproto = headers.get("sec-websocket-protocol", None)
 | 
				
			||||||
 | 
					        if not subproto or subproto.lower() not in [s.lower() for s in subprotocols]:
 | 
				
			||||||
 | 
					            error(f"Invalid subprotocol: {subprotocols}")
 | 
				
			||||||
 | 
					            return False, None
 | 
				
			||||||
 | 
					        subproto = subproto.lower()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    result = headers.get("sec-websocket-accept", None)
 | 
				
			||||||
 | 
					    if not result:
 | 
				
			||||||
 | 
					        return False, None
 | 
				
			||||||
 | 
					    result = result.lower()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if isinstance(result, str):
 | 
				
			||||||
 | 
					        result = result.encode("utf-8")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    value = f"{key}258EAFA5-E914-47DA-95CA-C5AB0DC85B11".encode("utf-8")
 | 
				
			||||||
 | 
					    hashed = base64encode(hashlib.sha1(value).digest()).strip().lower()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if hmac.compare_digest(hashed, result):
 | 
				
			||||||
 | 
					        return True, subproto
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        return False, None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def _create_sec_websocket_key() -> str:
 | 
				
			||||||
 | 
					    randomness = os.urandom(16)
 | 
				
			||||||
 | 
					    return base64encode(randomness).decode("utf-8").strip()
 | 
				
			||||||
							
								
								
									
										374
									
								
								src/libs/websocket/_http.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										374
									
								
								src/libs/websocket/_http.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,374 @@
 | 
				
			|||||||
 | 
					"""
 | 
				
			||||||
 | 
					_http.py
 | 
				
			||||||
 | 
					websocket - WebSocket client library for Python
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Copyright 2024 engn33r
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					You may obtain a copy of the License at
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					limitations under the License.
 | 
				
			||||||
 | 
					"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import errno
 | 
				
			||||||
 | 
					import os
 | 
				
			||||||
 | 
					import socket
 | 
				
			||||||
 | 
					from base64 import encodebytes as base64encode
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from ._exceptions import (
 | 
				
			||||||
 | 
					    WebSocketAddressException,
 | 
				
			||||||
 | 
					    WebSocketException,
 | 
				
			||||||
 | 
					    WebSocketProxyException,
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					from ._logging import debug, dump, trace
 | 
				
			||||||
 | 
					from ._socket import DEFAULT_SOCKET_OPTION, recv_line, send
 | 
				
			||||||
 | 
					from ._ssl_compat import HAVE_SSL, ssl
 | 
				
			||||||
 | 
					from ._url import get_proxy_info, parse_url
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					__all__ = ["proxy_info", "connect", "read_headers"]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					try:
 | 
				
			||||||
 | 
					    from python_socks._errors import *
 | 
				
			||||||
 | 
					    from python_socks._types import ProxyType
 | 
				
			||||||
 | 
					    from python_socks.sync import Proxy
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    HAVE_PYTHON_SOCKS = True
 | 
				
			||||||
 | 
					except:
 | 
				
			||||||
 | 
					    HAVE_PYTHON_SOCKS = False
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    class ProxyError(Exception):
 | 
				
			||||||
 | 
					        pass
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    class ProxyTimeoutError(Exception):
 | 
				
			||||||
 | 
					        pass
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    class ProxyConnectionError(Exception):
 | 
				
			||||||
 | 
					        pass
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class proxy_info:
 | 
				
			||||||
 | 
					    def __init__(self, **options):
 | 
				
			||||||
 | 
					        self.proxy_host = options.get("http_proxy_host", None)
 | 
				
			||||||
 | 
					        if self.proxy_host:
 | 
				
			||||||
 | 
					            self.proxy_port = options.get("http_proxy_port", 0)
 | 
				
			||||||
 | 
					            self.auth = options.get("http_proxy_auth", None)
 | 
				
			||||||
 | 
					            self.no_proxy = options.get("http_no_proxy", None)
 | 
				
			||||||
 | 
					            self.proxy_protocol = options.get("proxy_type", "http")
 | 
				
			||||||
 | 
					            # Note: If timeout not specified, default python-socks timeout is 60 seconds
 | 
				
			||||||
 | 
					            self.proxy_timeout = options.get("http_proxy_timeout", None)
 | 
				
			||||||
 | 
					            if self.proxy_protocol not in [
 | 
				
			||||||
 | 
					                "http",
 | 
				
			||||||
 | 
					                "socks4",
 | 
				
			||||||
 | 
					                "socks4a",
 | 
				
			||||||
 | 
					                "socks5",
 | 
				
			||||||
 | 
					                "socks5h",
 | 
				
			||||||
 | 
					            ]:
 | 
				
			||||||
 | 
					                raise ProxyError(
 | 
				
			||||||
 | 
					                    "Only http, socks4, socks5 proxy protocols are supported"
 | 
				
			||||||
 | 
					                )
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            self.proxy_port = 0
 | 
				
			||||||
 | 
					            self.auth = None
 | 
				
			||||||
 | 
					            self.no_proxy = None
 | 
				
			||||||
 | 
					            self.proxy_protocol = "http"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def _start_proxied_socket(url: str, options, proxy) -> tuple:
 | 
				
			||||||
 | 
					    if not HAVE_PYTHON_SOCKS:
 | 
				
			||||||
 | 
					        raise WebSocketException(
 | 
				
			||||||
 | 
					            "Python Socks is needed for SOCKS proxying but is not available"
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    hostname, port, resource, is_secure = parse_url(url)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if proxy.proxy_protocol == "socks4":
 | 
				
			||||||
 | 
					        rdns = False
 | 
				
			||||||
 | 
					        proxy_type = ProxyType.SOCKS4
 | 
				
			||||||
 | 
					    # socks4a sends DNS through proxy
 | 
				
			||||||
 | 
					    elif proxy.proxy_protocol == "socks4a":
 | 
				
			||||||
 | 
					        rdns = True
 | 
				
			||||||
 | 
					        proxy_type = ProxyType.SOCKS4
 | 
				
			||||||
 | 
					    elif proxy.proxy_protocol == "socks5":
 | 
				
			||||||
 | 
					        rdns = False
 | 
				
			||||||
 | 
					        proxy_type = ProxyType.SOCKS5
 | 
				
			||||||
 | 
					    # socks5h sends DNS through proxy
 | 
				
			||||||
 | 
					    elif proxy.proxy_protocol == "socks5h":
 | 
				
			||||||
 | 
					        rdns = True
 | 
				
			||||||
 | 
					        proxy_type = ProxyType.SOCKS5
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    ws_proxy = Proxy.create(
 | 
				
			||||||
 | 
					        proxy_type=proxy_type,
 | 
				
			||||||
 | 
					        host=proxy.proxy_host,
 | 
				
			||||||
 | 
					        port=int(proxy.proxy_port),
 | 
				
			||||||
 | 
					        username=proxy.auth[0] if proxy.auth else None,
 | 
				
			||||||
 | 
					        password=proxy.auth[1] if proxy.auth else None,
 | 
				
			||||||
 | 
					        rdns=rdns,
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    sock = ws_proxy.connect(hostname, port, timeout=proxy.proxy_timeout)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if is_secure:
 | 
				
			||||||
 | 
					        if HAVE_SSL:
 | 
				
			||||||
 | 
					            sock = _ssl_socket(sock, options.sslopt, hostname)
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            raise WebSocketException("SSL not available.")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return sock, (hostname, port, resource)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def connect(url: str, options, proxy, socket):
 | 
				
			||||||
 | 
					    # Use _start_proxied_socket() only for socks4 or socks5 proxy
 | 
				
			||||||
 | 
					    # Use _tunnel() for http proxy
 | 
				
			||||||
 | 
					    # TODO: Use python-socks for http protocol also, to standardize flow
 | 
				
			||||||
 | 
					    if proxy.proxy_host and not socket and proxy.proxy_protocol != "http":
 | 
				
			||||||
 | 
					        return _start_proxied_socket(url, options, proxy)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    hostname, port_from_url, resource, is_secure = parse_url(url)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if socket:
 | 
				
			||||||
 | 
					        return socket, (hostname, port_from_url, resource)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    addrinfo_list, need_tunnel, auth = _get_addrinfo_list(
 | 
				
			||||||
 | 
					        hostname, port_from_url, is_secure, proxy
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    if not addrinfo_list:
 | 
				
			||||||
 | 
					        raise WebSocketException(f"Host not found.: {hostname}:{port_from_url}")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    sock = None
 | 
				
			||||||
 | 
					    try:
 | 
				
			||||||
 | 
					        sock = _open_socket(addrinfo_list, options.sockopt, options.timeout)
 | 
				
			||||||
 | 
					        if need_tunnel:
 | 
				
			||||||
 | 
					            sock = _tunnel(sock, hostname, port_from_url, auth)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if is_secure:
 | 
				
			||||||
 | 
					            if HAVE_SSL:
 | 
				
			||||||
 | 
					                sock = _ssl_socket(sock, options.sslopt, hostname)
 | 
				
			||||||
 | 
					            else:
 | 
				
			||||||
 | 
					                raise WebSocketException("SSL not available.")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return sock, (hostname, port_from_url, resource)
 | 
				
			||||||
 | 
					    except:
 | 
				
			||||||
 | 
					        if sock:
 | 
				
			||||||
 | 
					            sock.close()
 | 
				
			||||||
 | 
					        raise
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def _get_addrinfo_list(hostname, port: int, is_secure: bool, proxy) -> tuple:
 | 
				
			||||||
 | 
					    phost, pport, pauth = get_proxy_info(
 | 
				
			||||||
 | 
					        hostname,
 | 
				
			||||||
 | 
					        is_secure,
 | 
				
			||||||
 | 
					        proxy.proxy_host,
 | 
				
			||||||
 | 
					        proxy.proxy_port,
 | 
				
			||||||
 | 
					        proxy.auth,
 | 
				
			||||||
 | 
					        proxy.no_proxy,
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    try:
 | 
				
			||||||
 | 
					        # when running on windows 10, getaddrinfo without socktype returns a socktype 0.
 | 
				
			||||||
 | 
					        # This generates an error exception: `_on_error: exception Socket type must be stream or datagram, not 0`
 | 
				
			||||||
 | 
					        # or `OSError: [Errno 22] Invalid argument` when creating socket. Force the socket type to SOCK_STREAM.
 | 
				
			||||||
 | 
					        if not phost:
 | 
				
			||||||
 | 
					            addrinfo_list = socket.getaddrinfo(
 | 
				
			||||||
 | 
					                hostname, port, 0, socket.SOCK_STREAM, socket.SOL_TCP
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					            return addrinfo_list, False, None
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            pport = pport and pport or 80
 | 
				
			||||||
 | 
					            # when running on windows 10, the getaddrinfo used above
 | 
				
			||||||
 | 
					            # returns a socktype 0. This generates an error exception:
 | 
				
			||||||
 | 
					            # _on_error: exception Socket type must be stream or datagram, not 0
 | 
				
			||||||
 | 
					            # Force the socket type to SOCK_STREAM
 | 
				
			||||||
 | 
					            addrinfo_list = socket.getaddrinfo(
 | 
				
			||||||
 | 
					                phost, pport, 0, socket.SOCK_STREAM, socket.SOL_TCP
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					            return addrinfo_list, True, pauth
 | 
				
			||||||
 | 
					    except socket.gaierror as e:
 | 
				
			||||||
 | 
					        raise WebSocketAddressException(e)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def _open_socket(addrinfo_list, sockopt, timeout):
 | 
				
			||||||
 | 
					    err = None
 | 
				
			||||||
 | 
					    for addrinfo in addrinfo_list:
 | 
				
			||||||
 | 
					        family, socktype, proto = addrinfo[:3]
 | 
				
			||||||
 | 
					        sock = socket.socket(family, socktype, proto)
 | 
				
			||||||
 | 
					        sock.settimeout(timeout)
 | 
				
			||||||
 | 
					        for opts in DEFAULT_SOCKET_OPTION:
 | 
				
			||||||
 | 
					            sock.setsockopt(*opts)
 | 
				
			||||||
 | 
					        for opts in sockopt:
 | 
				
			||||||
 | 
					            sock.setsockopt(*opts)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        address = addrinfo[4]
 | 
				
			||||||
 | 
					        err = None
 | 
				
			||||||
 | 
					        while not err:
 | 
				
			||||||
 | 
					            try:
 | 
				
			||||||
 | 
					                sock.connect(address)
 | 
				
			||||||
 | 
					            except socket.error as error:
 | 
				
			||||||
 | 
					                sock.close()
 | 
				
			||||||
 | 
					                error.remote_ip = str(address[0])
 | 
				
			||||||
 | 
					                try:
 | 
				
			||||||
 | 
					                    eConnRefused = (
 | 
				
			||||||
 | 
					                        errno.ECONNREFUSED,
 | 
				
			||||||
 | 
					                        errno.WSAECONNREFUSED,
 | 
				
			||||||
 | 
					                        errno.ENETUNREACH,
 | 
				
			||||||
 | 
					                    )
 | 
				
			||||||
 | 
					                except AttributeError:
 | 
				
			||||||
 | 
					                    eConnRefused = (errno.ECONNREFUSED, errno.ENETUNREACH)
 | 
				
			||||||
 | 
					                if error.errno not in eConnRefused:
 | 
				
			||||||
 | 
					                    raise error
 | 
				
			||||||
 | 
					                err = error
 | 
				
			||||||
 | 
					                continue
 | 
				
			||||||
 | 
					            else:
 | 
				
			||||||
 | 
					                break
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            continue
 | 
				
			||||||
 | 
					        break
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        if err:
 | 
				
			||||||
 | 
					            raise err
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return sock
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def _wrap_sni_socket(sock: socket.socket, sslopt: dict, hostname, check_hostname):
 | 
				
			||||||
 | 
					    context = sslopt.get("context", None)
 | 
				
			||||||
 | 
					    if not context:
 | 
				
			||||||
 | 
					        context = ssl.SSLContext(sslopt.get("ssl_version", ssl.PROTOCOL_TLS_CLIENT))
 | 
				
			||||||
 | 
					        # Non default context need to manually enable SSLKEYLOGFILE support by setting the keylog_filename attribute.
 | 
				
			||||||
 | 
					        # For more details see also:
 | 
				
			||||||
 | 
					        # * https://docs.python.org/3.8/library/ssl.html?highlight=sslkeylogfile#context-creation
 | 
				
			||||||
 | 
					        # * https://docs.python.org/3.8/library/ssl.html?highlight=sslkeylogfile#ssl.SSLContext.keylog_filename
 | 
				
			||||||
 | 
					        context.keylog_filename = os.environ.get("SSLKEYLOGFILE", None)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if sslopt.get("cert_reqs", ssl.CERT_NONE) != ssl.CERT_NONE:
 | 
				
			||||||
 | 
					            cafile = sslopt.get("ca_certs", None)
 | 
				
			||||||
 | 
					            capath = sslopt.get("ca_cert_path", None)
 | 
				
			||||||
 | 
					            if cafile or capath:
 | 
				
			||||||
 | 
					                context.load_verify_locations(cafile=cafile, capath=capath)
 | 
				
			||||||
 | 
					            elif hasattr(context, "load_default_certs"):
 | 
				
			||||||
 | 
					                context.load_default_certs(ssl.Purpose.SERVER_AUTH)
 | 
				
			||||||
 | 
					        if sslopt.get("certfile", None):
 | 
				
			||||||
 | 
					            context.load_cert_chain(
 | 
				
			||||||
 | 
					                sslopt["certfile"],
 | 
				
			||||||
 | 
					                sslopt.get("keyfile", None),
 | 
				
			||||||
 | 
					                sslopt.get("password", None),
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Python 3.10 switch to PROTOCOL_TLS_CLIENT defaults to "cert_reqs = ssl.CERT_REQUIRED" and "check_hostname = True"
 | 
				
			||||||
 | 
					        # If both disabled, set check_hostname before verify_mode
 | 
				
			||||||
 | 
					        # see https://github.com/liris/websocket-client/commit/b96a2e8fa765753e82eea531adb19716b52ca3ca#commitcomment-10803153
 | 
				
			||||||
 | 
					        if sslopt.get("cert_reqs", ssl.CERT_NONE) == ssl.CERT_NONE and not sslopt.get(
 | 
				
			||||||
 | 
					            "check_hostname", False
 | 
				
			||||||
 | 
					        ):
 | 
				
			||||||
 | 
					            context.check_hostname = False
 | 
				
			||||||
 | 
					            context.verify_mode = ssl.CERT_NONE
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            context.check_hostname = sslopt.get("check_hostname", True)
 | 
				
			||||||
 | 
					            context.verify_mode = sslopt.get("cert_reqs", ssl.CERT_REQUIRED)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if "ciphers" in sslopt:
 | 
				
			||||||
 | 
					            context.set_ciphers(sslopt["ciphers"])
 | 
				
			||||||
 | 
					        if "cert_chain" in sslopt:
 | 
				
			||||||
 | 
					            certfile, keyfile, password = sslopt["cert_chain"]
 | 
				
			||||||
 | 
					            context.load_cert_chain(certfile, keyfile, password)
 | 
				
			||||||
 | 
					        if "ecdh_curve" in sslopt:
 | 
				
			||||||
 | 
					            context.set_ecdh_curve(sslopt["ecdh_curve"])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return context.wrap_socket(
 | 
				
			||||||
 | 
					        sock,
 | 
				
			||||||
 | 
					        do_handshake_on_connect=sslopt.get("do_handshake_on_connect", True),
 | 
				
			||||||
 | 
					        suppress_ragged_eofs=sslopt.get("suppress_ragged_eofs", True),
 | 
				
			||||||
 | 
					        server_hostname=hostname,
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def _ssl_socket(sock: socket.socket, user_sslopt: dict, hostname):
 | 
				
			||||||
 | 
					    sslopt: dict = {"cert_reqs": ssl.CERT_REQUIRED}
 | 
				
			||||||
 | 
					    sslopt.update(user_sslopt)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    cert_path = os.environ.get("WEBSOCKET_CLIENT_CA_BUNDLE")
 | 
				
			||||||
 | 
					    if (
 | 
				
			||||||
 | 
					        cert_path
 | 
				
			||||||
 | 
					        and os.path.isfile(cert_path)
 | 
				
			||||||
 | 
					        and user_sslopt.get("ca_certs", None) is None
 | 
				
			||||||
 | 
					    ):
 | 
				
			||||||
 | 
					        sslopt["ca_certs"] = cert_path
 | 
				
			||||||
 | 
					    elif (
 | 
				
			||||||
 | 
					        cert_path
 | 
				
			||||||
 | 
					        and os.path.isdir(cert_path)
 | 
				
			||||||
 | 
					        and user_sslopt.get("ca_cert_path", None) is None
 | 
				
			||||||
 | 
					    ):
 | 
				
			||||||
 | 
					        sslopt["ca_cert_path"] = cert_path
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if sslopt.get("server_hostname", None):
 | 
				
			||||||
 | 
					        hostname = sslopt["server_hostname"]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    check_hostname = sslopt.get("check_hostname", True)
 | 
				
			||||||
 | 
					    sock = _wrap_sni_socket(sock, sslopt, hostname, check_hostname)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return sock
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def _tunnel(sock: socket.socket, host, port: int, auth) -> socket.socket:
 | 
				
			||||||
 | 
					    debug("Connecting proxy...")
 | 
				
			||||||
 | 
					    connect_header = f"CONNECT {host}:{port} HTTP/1.1\r\n"
 | 
				
			||||||
 | 
					    connect_header += f"Host: {host}:{port}\r\n"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # TODO: support digest auth.
 | 
				
			||||||
 | 
					    if auth and auth[0]:
 | 
				
			||||||
 | 
					        auth_str = auth[0]
 | 
				
			||||||
 | 
					        if auth[1]:
 | 
				
			||||||
 | 
					            auth_str += f":{auth[1]}"
 | 
				
			||||||
 | 
					        encoded_str = base64encode(auth_str.encode()).strip().decode().replace("\n", "")
 | 
				
			||||||
 | 
					        connect_header += f"Proxy-Authorization: Basic {encoded_str}\r\n"
 | 
				
			||||||
 | 
					    connect_header += "\r\n"
 | 
				
			||||||
 | 
					    dump("request header", connect_header)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    send(sock, connect_header)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    try:
 | 
				
			||||||
 | 
					        status, _, _ = read_headers(sock)
 | 
				
			||||||
 | 
					    except Exception as e:
 | 
				
			||||||
 | 
					        raise WebSocketProxyException(str(e))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if status != 200:
 | 
				
			||||||
 | 
					        raise WebSocketProxyException(f"failed CONNECT via proxy status: {status}")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return sock
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def read_headers(sock: socket.socket) -> tuple:
 | 
				
			||||||
 | 
					    status = None
 | 
				
			||||||
 | 
					    status_message = None
 | 
				
			||||||
 | 
					    headers: dict = {}
 | 
				
			||||||
 | 
					    trace("--- response header ---")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    while True:
 | 
				
			||||||
 | 
					        line = recv_line(sock)
 | 
				
			||||||
 | 
					        line = line.decode("utf-8").strip()
 | 
				
			||||||
 | 
					        if not line:
 | 
				
			||||||
 | 
					            break
 | 
				
			||||||
 | 
					        trace(line)
 | 
				
			||||||
 | 
					        if not status:
 | 
				
			||||||
 | 
					            status_info = line.split(" ", 2)
 | 
				
			||||||
 | 
					            status = int(status_info[1])
 | 
				
			||||||
 | 
					            if len(status_info) > 2:
 | 
				
			||||||
 | 
					                status_message = status_info[2]
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            kv = line.split(":", 1)
 | 
				
			||||||
 | 
					            if len(kv) != 2:
 | 
				
			||||||
 | 
					                raise WebSocketException("Invalid header")
 | 
				
			||||||
 | 
					            key, value = kv
 | 
				
			||||||
 | 
					            if key.lower() == "set-cookie" and headers.get("set-cookie"):
 | 
				
			||||||
 | 
					                headers["set-cookie"] = headers.get("set-cookie") + "; " + value.strip()
 | 
				
			||||||
 | 
					            else:
 | 
				
			||||||
 | 
					                headers[key.lower()] = value.strip()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    trace("-----------------------")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return status, headers, status_message
 | 
				
			||||||
							
								
								
									
										106
									
								
								src/libs/websocket/_logging.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										106
									
								
								src/libs/websocket/_logging.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,106 @@
 | 
				
			|||||||
 | 
					import logging
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					"""
 | 
				
			||||||
 | 
					_logging.py
 | 
				
			||||||
 | 
					websocket - WebSocket client library for Python
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Copyright 2024 engn33r
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					You may obtain a copy of the License at
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					limitations under the License.
 | 
				
			||||||
 | 
					"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					_logger = logging.getLogger("websocket")
 | 
				
			||||||
 | 
					try:
 | 
				
			||||||
 | 
					    from logging import NullHandler
 | 
				
			||||||
 | 
					except ImportError:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    class NullHandler(logging.Handler):
 | 
				
			||||||
 | 
					        def emit(self, record) -> None:
 | 
				
			||||||
 | 
					            pass
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					_logger.addHandler(NullHandler())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					_traceEnabled = False
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					__all__ = [
 | 
				
			||||||
 | 
					    "enableTrace",
 | 
				
			||||||
 | 
					    "dump",
 | 
				
			||||||
 | 
					    "error",
 | 
				
			||||||
 | 
					    "warning",
 | 
				
			||||||
 | 
					    "debug",
 | 
				
			||||||
 | 
					    "trace",
 | 
				
			||||||
 | 
					    "isEnabledForError",
 | 
				
			||||||
 | 
					    "isEnabledForDebug",
 | 
				
			||||||
 | 
					    "isEnabledForTrace",
 | 
				
			||||||
 | 
					]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def enableTrace(
 | 
				
			||||||
 | 
					    traceable: bool,
 | 
				
			||||||
 | 
					    handler: logging.StreamHandler = logging.StreamHandler(),
 | 
				
			||||||
 | 
					    level: str = "DEBUG",
 | 
				
			||||||
 | 
					) -> None:
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    Turn on/off the traceability.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Parameters
 | 
				
			||||||
 | 
					    ----------
 | 
				
			||||||
 | 
					    traceable: bool
 | 
				
			||||||
 | 
					        If set to True, traceability is enabled.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    global _traceEnabled
 | 
				
			||||||
 | 
					    _traceEnabled = traceable
 | 
				
			||||||
 | 
					    if traceable:
 | 
				
			||||||
 | 
					        _logger.addHandler(handler)
 | 
				
			||||||
 | 
					        _logger.setLevel(getattr(logging, level))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def dump(title: str, message: str) -> None:
 | 
				
			||||||
 | 
					    if _traceEnabled:
 | 
				
			||||||
 | 
					        _logger.debug(f"--- {title} ---")
 | 
				
			||||||
 | 
					        _logger.debug(message)
 | 
				
			||||||
 | 
					        _logger.debug("-----------------------")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def error(msg: str) -> None:
 | 
				
			||||||
 | 
					    _logger.error(msg)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def warning(msg: str) -> None:
 | 
				
			||||||
 | 
					    _logger.warning(msg)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def debug(msg: str) -> None:
 | 
				
			||||||
 | 
					    _logger.debug(msg)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def info(msg: str) -> None:
 | 
				
			||||||
 | 
					    _logger.info(msg)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def trace(msg: str) -> None:
 | 
				
			||||||
 | 
					    if _traceEnabled:
 | 
				
			||||||
 | 
					        _logger.debug(msg)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def isEnabledForError() -> bool:
 | 
				
			||||||
 | 
					    return _logger.isEnabledFor(logging.ERROR)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def isEnabledForDebug() -> bool:
 | 
				
			||||||
 | 
					    return _logger.isEnabledFor(logging.DEBUG)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def isEnabledForTrace() -> bool:
 | 
				
			||||||
 | 
					    return _traceEnabled
 | 
				
			||||||
							
								
								
									
										188
									
								
								src/libs/websocket/_socket.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										188
									
								
								src/libs/websocket/_socket.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,188 @@
 | 
				
			|||||||
 | 
					import errno
 | 
				
			||||||
 | 
					import selectors
 | 
				
			||||||
 | 
					import socket
 | 
				
			||||||
 | 
					from typing import Union
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from ._exceptions import (
 | 
				
			||||||
 | 
					    WebSocketConnectionClosedException,
 | 
				
			||||||
 | 
					    WebSocketTimeoutException,
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					from ._ssl_compat import SSLError, SSLWantReadError, SSLWantWriteError
 | 
				
			||||||
 | 
					from ._utils import extract_error_code, extract_err_message
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					"""
 | 
				
			||||||
 | 
					_socket.py
 | 
				
			||||||
 | 
					websocket - WebSocket client library for Python
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Copyright 2024 engn33r
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					You may obtain a copy of the License at
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					limitations under the License.
 | 
				
			||||||
 | 
					"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					DEFAULT_SOCKET_OPTION = [(socket.SOL_TCP, socket.TCP_NODELAY, 1)]
 | 
				
			||||||
 | 
					if hasattr(socket, "SO_KEEPALIVE"):
 | 
				
			||||||
 | 
					    DEFAULT_SOCKET_OPTION.append((socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1))
 | 
				
			||||||
 | 
					if hasattr(socket, "TCP_KEEPIDLE"):
 | 
				
			||||||
 | 
					    DEFAULT_SOCKET_OPTION.append((socket.SOL_TCP, socket.TCP_KEEPIDLE, 30))
 | 
				
			||||||
 | 
					if hasattr(socket, "TCP_KEEPINTVL"):
 | 
				
			||||||
 | 
					    DEFAULT_SOCKET_OPTION.append((socket.SOL_TCP, socket.TCP_KEEPINTVL, 10))
 | 
				
			||||||
 | 
					if hasattr(socket, "TCP_KEEPCNT"):
 | 
				
			||||||
 | 
					    DEFAULT_SOCKET_OPTION.append((socket.SOL_TCP, socket.TCP_KEEPCNT, 3))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					_default_timeout = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					__all__ = [
 | 
				
			||||||
 | 
					    "DEFAULT_SOCKET_OPTION",
 | 
				
			||||||
 | 
					    "sock_opt",
 | 
				
			||||||
 | 
					    "setdefaulttimeout",
 | 
				
			||||||
 | 
					    "getdefaulttimeout",
 | 
				
			||||||
 | 
					    "recv",
 | 
				
			||||||
 | 
					    "recv_line",
 | 
				
			||||||
 | 
					    "send",
 | 
				
			||||||
 | 
					]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class sock_opt:
 | 
				
			||||||
 | 
					    def __init__(self, sockopt: list, sslopt: dict) -> None:
 | 
				
			||||||
 | 
					        if sockopt is None:
 | 
				
			||||||
 | 
					            sockopt = []
 | 
				
			||||||
 | 
					        if sslopt is None:
 | 
				
			||||||
 | 
					            sslopt = {}
 | 
				
			||||||
 | 
					        self.sockopt = sockopt
 | 
				
			||||||
 | 
					        self.sslopt = sslopt
 | 
				
			||||||
 | 
					        self.timeout = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def setdefaulttimeout(timeout: Union[int, float, None]) -> None:
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    Set the global timeout setting to connect.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Parameters
 | 
				
			||||||
 | 
					    ----------
 | 
				
			||||||
 | 
					    timeout: int or float
 | 
				
			||||||
 | 
					        default socket timeout time (in seconds)
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    global _default_timeout
 | 
				
			||||||
 | 
					    _default_timeout = timeout
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def getdefaulttimeout() -> Union[int, float, None]:
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    Get default timeout
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Returns
 | 
				
			||||||
 | 
					    ----------
 | 
				
			||||||
 | 
					    _default_timeout: int or float
 | 
				
			||||||
 | 
					        Return the global timeout setting (in seconds) to connect.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    return _default_timeout
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def recv(sock: socket.socket, bufsize: int) -> bytes:
 | 
				
			||||||
 | 
					    if not sock:
 | 
				
			||||||
 | 
					        raise WebSocketConnectionClosedException("socket is already closed.")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def _recv():
 | 
				
			||||||
 | 
					        try:
 | 
				
			||||||
 | 
					            return sock.recv(bufsize)
 | 
				
			||||||
 | 
					        except SSLWantReadError:
 | 
				
			||||||
 | 
					            pass
 | 
				
			||||||
 | 
					        except socket.error as exc:
 | 
				
			||||||
 | 
					            error_code = extract_error_code(exc)
 | 
				
			||||||
 | 
					            if error_code not in [errno.EAGAIN, errno.EWOULDBLOCK]:
 | 
				
			||||||
 | 
					                raise
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        sel = selectors.DefaultSelector()
 | 
				
			||||||
 | 
					        sel.register(sock, selectors.EVENT_READ)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        r = sel.select(sock.gettimeout())
 | 
				
			||||||
 | 
					        sel.close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if r:
 | 
				
			||||||
 | 
					            return sock.recv(bufsize)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    try:
 | 
				
			||||||
 | 
					        if sock.gettimeout() == 0:
 | 
				
			||||||
 | 
					            bytes_ = sock.recv(bufsize)
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            bytes_ = _recv()
 | 
				
			||||||
 | 
					    except TimeoutError:
 | 
				
			||||||
 | 
					        raise WebSocketTimeoutException("Connection timed out")
 | 
				
			||||||
 | 
					    except socket.timeout as e:
 | 
				
			||||||
 | 
					        message = extract_err_message(e)
 | 
				
			||||||
 | 
					        raise WebSocketTimeoutException(message)
 | 
				
			||||||
 | 
					    except SSLError as e:
 | 
				
			||||||
 | 
					        message = extract_err_message(e)
 | 
				
			||||||
 | 
					        if isinstance(message, str) and "timed out" in message:
 | 
				
			||||||
 | 
					            raise WebSocketTimeoutException(message)
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            raise
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if not bytes_:
 | 
				
			||||||
 | 
					        raise WebSocketConnectionClosedException("Connection to remote host was lost.")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return bytes_
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def recv_line(sock: socket.socket) -> bytes:
 | 
				
			||||||
 | 
					    line = []
 | 
				
			||||||
 | 
					    while True:
 | 
				
			||||||
 | 
					        c = recv(sock, 1)
 | 
				
			||||||
 | 
					        line.append(c)
 | 
				
			||||||
 | 
					        if c == b"\n":
 | 
				
			||||||
 | 
					            break
 | 
				
			||||||
 | 
					    return b"".join(line)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def send(sock: socket.socket, data: Union[bytes, str]) -> int:
 | 
				
			||||||
 | 
					    if isinstance(data, str):
 | 
				
			||||||
 | 
					        data = data.encode("utf-8")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if not sock:
 | 
				
			||||||
 | 
					        raise WebSocketConnectionClosedException("socket is already closed.")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def _send():
 | 
				
			||||||
 | 
					        try:
 | 
				
			||||||
 | 
					            return sock.send(data)
 | 
				
			||||||
 | 
					        except SSLWantWriteError:
 | 
				
			||||||
 | 
					            pass
 | 
				
			||||||
 | 
					        except socket.error as exc:
 | 
				
			||||||
 | 
					            error_code = extract_error_code(exc)
 | 
				
			||||||
 | 
					            if error_code is None:
 | 
				
			||||||
 | 
					                raise
 | 
				
			||||||
 | 
					            if error_code not in [errno.EAGAIN, errno.EWOULDBLOCK]:
 | 
				
			||||||
 | 
					                raise
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        sel = selectors.DefaultSelector()
 | 
				
			||||||
 | 
					        sel.register(sock, selectors.EVENT_WRITE)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        w = sel.select(sock.gettimeout())
 | 
				
			||||||
 | 
					        sel.close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if w:
 | 
				
			||||||
 | 
					            return sock.send(data)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    try:
 | 
				
			||||||
 | 
					        if sock.gettimeout() == 0:
 | 
				
			||||||
 | 
					            return sock.send(data)
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            return _send()
 | 
				
			||||||
 | 
					    except socket.timeout as e:
 | 
				
			||||||
 | 
					        message = extract_err_message(e)
 | 
				
			||||||
 | 
					        raise WebSocketTimeoutException(message)
 | 
				
			||||||
 | 
					    except Exception as e:
 | 
				
			||||||
 | 
					        message = extract_err_message(e)
 | 
				
			||||||
 | 
					        if isinstance(message, str) and "timed out" in message:
 | 
				
			||||||
 | 
					            raise WebSocketTimeoutException(message)
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            raise
 | 
				
			||||||
							
								
								
									
										49
									
								
								src/libs/websocket/_ssl_compat.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										49
									
								
								src/libs/websocket/_ssl_compat.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,49 @@
 | 
				
			|||||||
 | 
					"""
 | 
				
			||||||
 | 
					_ssl_compat.py
 | 
				
			||||||
 | 
					websocket - WebSocket client library for Python
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Copyright 2024 engn33r
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					You may obtain a copy of the License at
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					limitations under the License.
 | 
				
			||||||
 | 
					"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					__all__ = [
 | 
				
			||||||
 | 
					    "HAVE_SSL",
 | 
				
			||||||
 | 
					    "ssl",
 | 
				
			||||||
 | 
					    "SSLError",
 | 
				
			||||||
 | 
					    "SSLEOFError",
 | 
				
			||||||
 | 
					    "SSLWantReadError",
 | 
				
			||||||
 | 
					    "SSLWantWriteError",
 | 
				
			||||||
 | 
					]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					try:
 | 
				
			||||||
 | 
					    import ssl
 | 
				
			||||||
 | 
					    from ssl import SSLError, SSLEOFError, SSLWantReadError, SSLWantWriteError
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    HAVE_SSL = True
 | 
				
			||||||
 | 
					except ImportError:
 | 
				
			||||||
 | 
					    # dummy class of SSLError for environment without ssl support
 | 
				
			||||||
 | 
					    class SSLError(Exception):
 | 
				
			||||||
 | 
					        pass
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    class SSLEOFError(Exception):
 | 
				
			||||||
 | 
					        pass
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    class SSLWantReadError(Exception):
 | 
				
			||||||
 | 
					        pass
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    class SSLWantWriteError(Exception):
 | 
				
			||||||
 | 
					        pass
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    ssl = None
 | 
				
			||||||
 | 
					    HAVE_SSL = False
 | 
				
			||||||
							
								
								
									
										190
									
								
								src/libs/websocket/_url.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										190
									
								
								src/libs/websocket/_url.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,190 @@
 | 
				
			|||||||
 | 
					import os
 | 
				
			||||||
 | 
					import socket
 | 
				
			||||||
 | 
					import struct
 | 
				
			||||||
 | 
					from typing import Optional
 | 
				
			||||||
 | 
					from urllib.parse import unquote, urlparse
 | 
				
			||||||
 | 
					from ._exceptions import WebSocketProxyException
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					"""
 | 
				
			||||||
 | 
					_url.py
 | 
				
			||||||
 | 
					websocket - WebSocket client library for Python
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Copyright 2024 engn33r
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					You may obtain a copy of the License at
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					limitations under the License.
 | 
				
			||||||
 | 
					"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					__all__ = ["parse_url", "get_proxy_info"]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def parse_url(url: str) -> tuple:
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    parse url and the result is tuple of
 | 
				
			||||||
 | 
					    (hostname, port, resource path and the flag of secure mode)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Parameters
 | 
				
			||||||
 | 
					    ----------
 | 
				
			||||||
 | 
					    url: str
 | 
				
			||||||
 | 
					        url string.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    if ":" not in url:
 | 
				
			||||||
 | 
					        raise ValueError("url is invalid")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    scheme, url = url.split(":", 1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    parsed = urlparse(url, scheme="http")
 | 
				
			||||||
 | 
					    if parsed.hostname:
 | 
				
			||||||
 | 
					        hostname = parsed.hostname
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        raise ValueError("hostname is invalid")
 | 
				
			||||||
 | 
					    port = 0
 | 
				
			||||||
 | 
					    if parsed.port:
 | 
				
			||||||
 | 
					        port = parsed.port
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    is_secure = False
 | 
				
			||||||
 | 
					    if scheme == "ws":
 | 
				
			||||||
 | 
					        if not port:
 | 
				
			||||||
 | 
					            port = 80
 | 
				
			||||||
 | 
					    elif scheme == "wss":
 | 
				
			||||||
 | 
					        is_secure = True
 | 
				
			||||||
 | 
					        if not port:
 | 
				
			||||||
 | 
					            port = 443
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        raise ValueError("scheme %s is invalid" % scheme)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if parsed.path:
 | 
				
			||||||
 | 
					        resource = parsed.path
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        resource = "/"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if parsed.query:
 | 
				
			||||||
 | 
					        resource += f"?{parsed.query}"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return hostname, port, resource, is_secure
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					DEFAULT_NO_PROXY_HOST = ["localhost", "127.0.0.1"]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def _is_ip_address(addr: str) -> bool:
 | 
				
			||||||
 | 
					    try:
 | 
				
			||||||
 | 
					        socket.inet_aton(addr)
 | 
				
			||||||
 | 
					    except socket.error:
 | 
				
			||||||
 | 
					        return False
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        return True
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def _is_subnet_address(hostname: str) -> bool:
 | 
				
			||||||
 | 
					    try:
 | 
				
			||||||
 | 
					        addr, netmask = hostname.split("/")
 | 
				
			||||||
 | 
					        return _is_ip_address(addr) and 0 <= int(netmask) < 32
 | 
				
			||||||
 | 
					    except ValueError:
 | 
				
			||||||
 | 
					        return False
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def _is_address_in_network(ip: str, net: str) -> bool:
 | 
				
			||||||
 | 
					    ipaddr: int = struct.unpack("!I", socket.inet_aton(ip))[0]
 | 
				
			||||||
 | 
					    netaddr, netmask = net.split("/")
 | 
				
			||||||
 | 
					    netaddr: int = struct.unpack("!I", socket.inet_aton(netaddr))[0]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    netmask = (0xFFFFFFFF << (32 - int(netmask))) & 0xFFFFFFFF
 | 
				
			||||||
 | 
					    return ipaddr & netmask == netaddr
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def _is_no_proxy_host(hostname: str, no_proxy: Optional[list]) -> bool:
 | 
				
			||||||
 | 
					    if not no_proxy:
 | 
				
			||||||
 | 
					        if v := os.environ.get("no_proxy", os.environ.get("NO_PROXY", "")).replace(
 | 
				
			||||||
 | 
					            " ", ""
 | 
				
			||||||
 | 
					        ):
 | 
				
			||||||
 | 
					            no_proxy = v.split(",")
 | 
				
			||||||
 | 
					    if not no_proxy:
 | 
				
			||||||
 | 
					        no_proxy = DEFAULT_NO_PROXY_HOST
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if "*" in no_proxy:
 | 
				
			||||||
 | 
					        return True
 | 
				
			||||||
 | 
					    if hostname in no_proxy:
 | 
				
			||||||
 | 
					        return True
 | 
				
			||||||
 | 
					    if _is_ip_address(hostname):
 | 
				
			||||||
 | 
					        return any(
 | 
				
			||||||
 | 
					            [
 | 
				
			||||||
 | 
					                _is_address_in_network(hostname, subnet)
 | 
				
			||||||
 | 
					                for subnet in no_proxy
 | 
				
			||||||
 | 
					                if _is_subnet_address(subnet)
 | 
				
			||||||
 | 
					            ]
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					    for domain in [domain for domain in no_proxy if domain.startswith(".")]:
 | 
				
			||||||
 | 
					        if hostname.endswith(domain):
 | 
				
			||||||
 | 
					            return True
 | 
				
			||||||
 | 
					    return False
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def get_proxy_info(
 | 
				
			||||||
 | 
					    hostname: str,
 | 
				
			||||||
 | 
					    is_secure: bool,
 | 
				
			||||||
 | 
					    proxy_host: Optional[str] = None,
 | 
				
			||||||
 | 
					    proxy_port: int = 0,
 | 
				
			||||||
 | 
					    proxy_auth: Optional[tuple] = None,
 | 
				
			||||||
 | 
					    no_proxy: Optional[list] = None,
 | 
				
			||||||
 | 
					    proxy_type: str = "http",
 | 
				
			||||||
 | 
					) -> tuple:
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    Try to retrieve proxy host and port from environment
 | 
				
			||||||
 | 
					    if not provided in options.
 | 
				
			||||||
 | 
					    Result is (proxy_host, proxy_port, proxy_auth).
 | 
				
			||||||
 | 
					    proxy_auth is tuple of username and password
 | 
				
			||||||
 | 
					    of proxy authentication information.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Parameters
 | 
				
			||||||
 | 
					    ----------
 | 
				
			||||||
 | 
					    hostname: str
 | 
				
			||||||
 | 
					        Websocket server name.
 | 
				
			||||||
 | 
					    is_secure: bool
 | 
				
			||||||
 | 
					        Is the connection secure? (wss) looks for "https_proxy" in env
 | 
				
			||||||
 | 
					        instead of "http_proxy"
 | 
				
			||||||
 | 
					    proxy_host: str
 | 
				
			||||||
 | 
					        http proxy host name.
 | 
				
			||||||
 | 
					    proxy_port: str or int
 | 
				
			||||||
 | 
					        http proxy port.
 | 
				
			||||||
 | 
					    no_proxy: list
 | 
				
			||||||
 | 
					        Whitelisted host names that don't use the proxy.
 | 
				
			||||||
 | 
					    proxy_auth: tuple
 | 
				
			||||||
 | 
					        HTTP proxy auth information. Tuple of username and password. Default is None.
 | 
				
			||||||
 | 
					    proxy_type: str
 | 
				
			||||||
 | 
					        Specify the proxy protocol (http, socks4, socks4a, socks5, socks5h). Default is "http".
 | 
				
			||||||
 | 
					        Use socks4a or socks5h if you want to send DNS requests through the proxy.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    if _is_no_proxy_host(hostname, no_proxy):
 | 
				
			||||||
 | 
					        return None, 0, None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if proxy_host:
 | 
				
			||||||
 | 
					        if not proxy_port:
 | 
				
			||||||
 | 
					            raise WebSocketProxyException("Cannot use port 0 when proxy_host specified")
 | 
				
			||||||
 | 
					        port = proxy_port
 | 
				
			||||||
 | 
					        auth = proxy_auth
 | 
				
			||||||
 | 
					        return proxy_host, port, auth
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    env_key = "https_proxy" if is_secure else "http_proxy"
 | 
				
			||||||
 | 
					    value = os.environ.get(env_key, os.environ.get(env_key.upper(), "")).replace(
 | 
				
			||||||
 | 
					        " ", ""
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    if value:
 | 
				
			||||||
 | 
					        proxy = urlparse(value)
 | 
				
			||||||
 | 
					        auth = (
 | 
				
			||||||
 | 
					            (unquote(proxy.username), unquote(proxy.password))
 | 
				
			||||||
 | 
					            if proxy.username
 | 
				
			||||||
 | 
					            else None
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        return proxy.hostname, proxy.port, auth
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return None, 0, None
 | 
				
			||||||
							
								
								
									
										459
									
								
								src/libs/websocket/_utils.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										459
									
								
								src/libs/websocket/_utils.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,459 @@
 | 
				
			|||||||
 | 
					from typing import Union
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					"""
 | 
				
			||||||
 | 
					_url.py
 | 
				
			||||||
 | 
					websocket - WebSocket client library for Python
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Copyright 2024 engn33r
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					You may obtain a copy of the License at
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					limitations under the License.
 | 
				
			||||||
 | 
					"""
 | 
				
			||||||
 | 
					__all__ = ["NoLock", "validate_utf8", "extract_err_message", "extract_error_code"]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class NoLock:
 | 
				
			||||||
 | 
					    def __enter__(self) -> None:
 | 
				
			||||||
 | 
					        pass
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __exit__(self, exc_type, exc_value, traceback) -> None:
 | 
				
			||||||
 | 
					        pass
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					try:
 | 
				
			||||||
 | 
					    # If wsaccel is available we use compiled routines to validate UTF-8
 | 
				
			||||||
 | 
					    # strings.
 | 
				
			||||||
 | 
					    from wsaccel.utf8validator import Utf8Validator
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def _validate_utf8(utfbytes: Union[str, bytes]) -> bool:
 | 
				
			||||||
 | 
					        result: bool = Utf8Validator().validate(utfbytes)[0]
 | 
				
			||||||
 | 
					        return result
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					except ImportError:
 | 
				
			||||||
 | 
					    # UTF-8 validator
 | 
				
			||||||
 | 
					    # python implementation of http://bjoern.hoehrmann.de/utf-8/decoder/dfa/
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    _UTF8_ACCEPT = 0
 | 
				
			||||||
 | 
					    _UTF8_REJECT = 12
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    _UTF8D = [
 | 
				
			||||||
 | 
					        # The first part of the table maps bytes to character classes that
 | 
				
			||||||
 | 
					        # to reduce the size of the transition table and create bitmasks.
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        1,
 | 
				
			||||||
 | 
					        1,
 | 
				
			||||||
 | 
					        1,
 | 
				
			||||||
 | 
					        1,
 | 
				
			||||||
 | 
					        1,
 | 
				
			||||||
 | 
					        1,
 | 
				
			||||||
 | 
					        1,
 | 
				
			||||||
 | 
					        1,
 | 
				
			||||||
 | 
					        1,
 | 
				
			||||||
 | 
					        1,
 | 
				
			||||||
 | 
					        1,
 | 
				
			||||||
 | 
					        1,
 | 
				
			||||||
 | 
					        1,
 | 
				
			||||||
 | 
					        1,
 | 
				
			||||||
 | 
					        1,
 | 
				
			||||||
 | 
					        1,
 | 
				
			||||||
 | 
					        9,
 | 
				
			||||||
 | 
					        9,
 | 
				
			||||||
 | 
					        9,
 | 
				
			||||||
 | 
					        9,
 | 
				
			||||||
 | 
					        9,
 | 
				
			||||||
 | 
					        9,
 | 
				
			||||||
 | 
					        9,
 | 
				
			||||||
 | 
					        9,
 | 
				
			||||||
 | 
					        9,
 | 
				
			||||||
 | 
					        9,
 | 
				
			||||||
 | 
					        9,
 | 
				
			||||||
 | 
					        9,
 | 
				
			||||||
 | 
					        9,
 | 
				
			||||||
 | 
					        9,
 | 
				
			||||||
 | 
					        9,
 | 
				
			||||||
 | 
					        9,
 | 
				
			||||||
 | 
					        7,
 | 
				
			||||||
 | 
					        7,
 | 
				
			||||||
 | 
					        7,
 | 
				
			||||||
 | 
					        7,
 | 
				
			||||||
 | 
					        7,
 | 
				
			||||||
 | 
					        7,
 | 
				
			||||||
 | 
					        7,
 | 
				
			||||||
 | 
					        7,
 | 
				
			||||||
 | 
					        7,
 | 
				
			||||||
 | 
					        7,
 | 
				
			||||||
 | 
					        7,
 | 
				
			||||||
 | 
					        7,
 | 
				
			||||||
 | 
					        7,
 | 
				
			||||||
 | 
					        7,
 | 
				
			||||||
 | 
					        7,
 | 
				
			||||||
 | 
					        7,
 | 
				
			||||||
 | 
					        7,
 | 
				
			||||||
 | 
					        7,
 | 
				
			||||||
 | 
					        7,
 | 
				
			||||||
 | 
					        7,
 | 
				
			||||||
 | 
					        7,
 | 
				
			||||||
 | 
					        7,
 | 
				
			||||||
 | 
					        7,
 | 
				
			||||||
 | 
					        7,
 | 
				
			||||||
 | 
					        7,
 | 
				
			||||||
 | 
					        7,
 | 
				
			||||||
 | 
					        7,
 | 
				
			||||||
 | 
					        7,
 | 
				
			||||||
 | 
					        7,
 | 
				
			||||||
 | 
					        7,
 | 
				
			||||||
 | 
					        7,
 | 
				
			||||||
 | 
					        7,
 | 
				
			||||||
 | 
					        8,
 | 
				
			||||||
 | 
					        8,
 | 
				
			||||||
 | 
					        2,
 | 
				
			||||||
 | 
					        2,
 | 
				
			||||||
 | 
					        2,
 | 
				
			||||||
 | 
					        2,
 | 
				
			||||||
 | 
					        2,
 | 
				
			||||||
 | 
					        2,
 | 
				
			||||||
 | 
					        2,
 | 
				
			||||||
 | 
					        2,
 | 
				
			||||||
 | 
					        2,
 | 
				
			||||||
 | 
					        2,
 | 
				
			||||||
 | 
					        2,
 | 
				
			||||||
 | 
					        2,
 | 
				
			||||||
 | 
					        2,
 | 
				
			||||||
 | 
					        2,
 | 
				
			||||||
 | 
					        2,
 | 
				
			||||||
 | 
					        2,
 | 
				
			||||||
 | 
					        2,
 | 
				
			||||||
 | 
					        2,
 | 
				
			||||||
 | 
					        2,
 | 
				
			||||||
 | 
					        2,
 | 
				
			||||||
 | 
					        2,
 | 
				
			||||||
 | 
					        2,
 | 
				
			||||||
 | 
					        2,
 | 
				
			||||||
 | 
					        2,
 | 
				
			||||||
 | 
					        2,
 | 
				
			||||||
 | 
					        2,
 | 
				
			||||||
 | 
					        2,
 | 
				
			||||||
 | 
					        2,
 | 
				
			||||||
 | 
					        2,
 | 
				
			||||||
 | 
					        2,
 | 
				
			||||||
 | 
					        10,
 | 
				
			||||||
 | 
					        3,
 | 
				
			||||||
 | 
					        3,
 | 
				
			||||||
 | 
					        3,
 | 
				
			||||||
 | 
					        3,
 | 
				
			||||||
 | 
					        3,
 | 
				
			||||||
 | 
					        3,
 | 
				
			||||||
 | 
					        3,
 | 
				
			||||||
 | 
					        3,
 | 
				
			||||||
 | 
					        3,
 | 
				
			||||||
 | 
					        3,
 | 
				
			||||||
 | 
					        3,
 | 
				
			||||||
 | 
					        3,
 | 
				
			||||||
 | 
					        4,
 | 
				
			||||||
 | 
					        3,
 | 
				
			||||||
 | 
					        3,
 | 
				
			||||||
 | 
					        11,
 | 
				
			||||||
 | 
					        6,
 | 
				
			||||||
 | 
					        6,
 | 
				
			||||||
 | 
					        6,
 | 
				
			||||||
 | 
					        5,
 | 
				
			||||||
 | 
					        8,
 | 
				
			||||||
 | 
					        8,
 | 
				
			||||||
 | 
					        8,
 | 
				
			||||||
 | 
					        8,
 | 
				
			||||||
 | 
					        8,
 | 
				
			||||||
 | 
					        8,
 | 
				
			||||||
 | 
					        8,
 | 
				
			||||||
 | 
					        8,
 | 
				
			||||||
 | 
					        8,
 | 
				
			||||||
 | 
					        8,
 | 
				
			||||||
 | 
					        8,
 | 
				
			||||||
 | 
					        # The second part is a transition table that maps a combination
 | 
				
			||||||
 | 
					        # of a state of the automaton and a character class to a state.
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        12,
 | 
				
			||||||
 | 
					        24,
 | 
				
			||||||
 | 
					        36,
 | 
				
			||||||
 | 
					        60,
 | 
				
			||||||
 | 
					        96,
 | 
				
			||||||
 | 
					        84,
 | 
				
			||||||
 | 
					        12,
 | 
				
			||||||
 | 
					        12,
 | 
				
			||||||
 | 
					        12,
 | 
				
			||||||
 | 
					        48,
 | 
				
			||||||
 | 
					        72,
 | 
				
			||||||
 | 
					        12,
 | 
				
			||||||
 | 
					        12,
 | 
				
			||||||
 | 
					        12,
 | 
				
			||||||
 | 
					        12,
 | 
				
			||||||
 | 
					        12,
 | 
				
			||||||
 | 
					        12,
 | 
				
			||||||
 | 
					        12,
 | 
				
			||||||
 | 
					        12,
 | 
				
			||||||
 | 
					        12,
 | 
				
			||||||
 | 
					        12,
 | 
				
			||||||
 | 
					        12,
 | 
				
			||||||
 | 
					        12,
 | 
				
			||||||
 | 
					        12,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        12,
 | 
				
			||||||
 | 
					        12,
 | 
				
			||||||
 | 
					        12,
 | 
				
			||||||
 | 
					        12,
 | 
				
			||||||
 | 
					        12,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        12,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        12,
 | 
				
			||||||
 | 
					        12,
 | 
				
			||||||
 | 
					        12,
 | 
				
			||||||
 | 
					        24,
 | 
				
			||||||
 | 
					        12,
 | 
				
			||||||
 | 
					        12,
 | 
				
			||||||
 | 
					        12,
 | 
				
			||||||
 | 
					        12,
 | 
				
			||||||
 | 
					        12,
 | 
				
			||||||
 | 
					        24,
 | 
				
			||||||
 | 
					        12,
 | 
				
			||||||
 | 
					        24,
 | 
				
			||||||
 | 
					        12,
 | 
				
			||||||
 | 
					        12,
 | 
				
			||||||
 | 
					        12,
 | 
				
			||||||
 | 
					        12,
 | 
				
			||||||
 | 
					        12,
 | 
				
			||||||
 | 
					        12,
 | 
				
			||||||
 | 
					        12,
 | 
				
			||||||
 | 
					        12,
 | 
				
			||||||
 | 
					        12,
 | 
				
			||||||
 | 
					        24,
 | 
				
			||||||
 | 
					        12,
 | 
				
			||||||
 | 
					        12,
 | 
				
			||||||
 | 
					        12,
 | 
				
			||||||
 | 
					        12,
 | 
				
			||||||
 | 
					        12,
 | 
				
			||||||
 | 
					        24,
 | 
				
			||||||
 | 
					        12,
 | 
				
			||||||
 | 
					        12,
 | 
				
			||||||
 | 
					        12,
 | 
				
			||||||
 | 
					        12,
 | 
				
			||||||
 | 
					        12,
 | 
				
			||||||
 | 
					        12,
 | 
				
			||||||
 | 
					        12,
 | 
				
			||||||
 | 
					        24,
 | 
				
			||||||
 | 
					        12,
 | 
				
			||||||
 | 
					        12,
 | 
				
			||||||
 | 
					        12,
 | 
				
			||||||
 | 
					        12,
 | 
				
			||||||
 | 
					        12,
 | 
				
			||||||
 | 
					        12,
 | 
				
			||||||
 | 
					        12,
 | 
				
			||||||
 | 
					        12,
 | 
				
			||||||
 | 
					        12,
 | 
				
			||||||
 | 
					        36,
 | 
				
			||||||
 | 
					        12,
 | 
				
			||||||
 | 
					        36,
 | 
				
			||||||
 | 
					        12,
 | 
				
			||||||
 | 
					        12,
 | 
				
			||||||
 | 
					        12,
 | 
				
			||||||
 | 
					        36,
 | 
				
			||||||
 | 
					        12,
 | 
				
			||||||
 | 
					        12,
 | 
				
			||||||
 | 
					        12,
 | 
				
			||||||
 | 
					        12,
 | 
				
			||||||
 | 
					        12,
 | 
				
			||||||
 | 
					        36,
 | 
				
			||||||
 | 
					        12,
 | 
				
			||||||
 | 
					        36,
 | 
				
			||||||
 | 
					        12,
 | 
				
			||||||
 | 
					        12,
 | 
				
			||||||
 | 
					        12,
 | 
				
			||||||
 | 
					        36,
 | 
				
			||||||
 | 
					        12,
 | 
				
			||||||
 | 
					        12,
 | 
				
			||||||
 | 
					        12,
 | 
				
			||||||
 | 
					        12,
 | 
				
			||||||
 | 
					        12,
 | 
				
			||||||
 | 
					        12,
 | 
				
			||||||
 | 
					        12,
 | 
				
			||||||
 | 
					        12,
 | 
				
			||||||
 | 
					        12,
 | 
				
			||||||
 | 
					        12,
 | 
				
			||||||
 | 
					    ]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def _decode(state: int, codep: int, ch: int) -> tuple:
 | 
				
			||||||
 | 
					        tp = _UTF8D[ch]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        codep = (
 | 
				
			||||||
 | 
					            (ch & 0x3F) | (codep << 6) if (state != _UTF8_ACCEPT) else (0xFF >> tp) & ch
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        state = _UTF8D[256 + state + tp]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return state, codep
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def _validate_utf8(utfbytes: Union[str, bytes]) -> bool:
 | 
				
			||||||
 | 
					        state = _UTF8_ACCEPT
 | 
				
			||||||
 | 
					        codep = 0
 | 
				
			||||||
 | 
					        for i in utfbytes:
 | 
				
			||||||
 | 
					            state, codep = _decode(state, codep, int(i))
 | 
				
			||||||
 | 
					            if state == _UTF8_REJECT:
 | 
				
			||||||
 | 
					                return False
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return True
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def validate_utf8(utfbytes: Union[str, bytes]) -> bool:
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    validate utf8 byte string.
 | 
				
			||||||
 | 
					    utfbytes: utf byte string to check.
 | 
				
			||||||
 | 
					    return value: if valid utf8 string, return true. Otherwise, return false.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    return _validate_utf8(utfbytes)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def extract_err_message(exception: Exception) -> Union[str, None]:
 | 
				
			||||||
 | 
					    if exception.args:
 | 
				
			||||||
 | 
					        exception_message: str = exception.args[0]
 | 
				
			||||||
 | 
					        return exception_message
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        return None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def extract_error_code(exception: Exception) -> Union[int, None]:
 | 
				
			||||||
 | 
					    if exception.args and len(exception.args) > 1:
 | 
				
			||||||
 | 
					        return exception.args[0] if isinstance(exception.args[0], int) else None
 | 
				
			||||||
							
								
								
									
										244
									
								
								src/libs/websocket/_wsdump.py
									
									
									
									
									
										Executable file
									
								
							
							
						
						
									
										244
									
								
								src/libs/websocket/_wsdump.py
									
									
									
									
									
										Executable file
									
								
							@@ -0,0 +1,244 @@
 | 
				
			|||||||
 | 
					#!/usr/bin/env python3
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					"""
 | 
				
			||||||
 | 
					wsdump.py
 | 
				
			||||||
 | 
					websocket - WebSocket client library for Python
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Copyright 2024 engn33r
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					You may obtain a copy of the License at
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					limitations under the License.
 | 
				
			||||||
 | 
					"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import argparse
 | 
				
			||||||
 | 
					import code
 | 
				
			||||||
 | 
					import gzip
 | 
				
			||||||
 | 
					import ssl
 | 
				
			||||||
 | 
					import sys
 | 
				
			||||||
 | 
					import threading
 | 
				
			||||||
 | 
					import time
 | 
				
			||||||
 | 
					import zlib
 | 
				
			||||||
 | 
					from urllib.parse import urlparse
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import websocket
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					try:
 | 
				
			||||||
 | 
					    import readline
 | 
				
			||||||
 | 
					except ImportError:
 | 
				
			||||||
 | 
					    pass
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def get_encoding() -> str:
 | 
				
			||||||
 | 
					    encoding = getattr(sys.stdin, "encoding", "")
 | 
				
			||||||
 | 
					    if not encoding:
 | 
				
			||||||
 | 
					        return "utf-8"
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        return encoding.lower()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					OPCODE_DATA = (websocket.ABNF.OPCODE_TEXT, websocket.ABNF.OPCODE_BINARY)
 | 
				
			||||||
 | 
					ENCODING = get_encoding()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class VAction(argparse.Action):
 | 
				
			||||||
 | 
					    def __call__(
 | 
				
			||||||
 | 
					        self,
 | 
				
			||||||
 | 
					        parser: argparse.Namespace,
 | 
				
			||||||
 | 
					        args: tuple,
 | 
				
			||||||
 | 
					        values: str,
 | 
				
			||||||
 | 
					        option_string: str = None,
 | 
				
			||||||
 | 
					    ) -> None:
 | 
				
			||||||
 | 
					        if values is None:
 | 
				
			||||||
 | 
					            values = "1"
 | 
				
			||||||
 | 
					        try:
 | 
				
			||||||
 | 
					            values = int(values)
 | 
				
			||||||
 | 
					        except ValueError:
 | 
				
			||||||
 | 
					            values = values.count("v") + 1
 | 
				
			||||||
 | 
					        setattr(args, self.dest, values)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def parse_args() -> argparse.Namespace:
 | 
				
			||||||
 | 
					    parser = argparse.ArgumentParser(description="WebSocket Simple Dump Tool")
 | 
				
			||||||
 | 
					    parser.add_argument(
 | 
				
			||||||
 | 
					        "url", metavar="ws_url", help="websocket url. ex. ws://echo.websocket.events/"
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    parser.add_argument("-p", "--proxy", help="proxy url. ex. http://127.0.0.1:8080")
 | 
				
			||||||
 | 
					    parser.add_argument(
 | 
				
			||||||
 | 
					        "-v",
 | 
				
			||||||
 | 
					        "--verbose",
 | 
				
			||||||
 | 
					        default=0,
 | 
				
			||||||
 | 
					        nargs="?",
 | 
				
			||||||
 | 
					        action=VAction,
 | 
				
			||||||
 | 
					        dest="verbose",
 | 
				
			||||||
 | 
					        help="set verbose mode. If set to 1, show opcode. "
 | 
				
			||||||
 | 
					        "If set to 2, enable to trace  websocket module",
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    parser.add_argument(
 | 
				
			||||||
 | 
					        "-n", "--nocert", action="store_true", help="Ignore invalid SSL cert"
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    parser.add_argument("-r", "--raw", action="store_true", help="raw output")
 | 
				
			||||||
 | 
					    parser.add_argument("-s", "--subprotocols", nargs="*", help="Set subprotocols")
 | 
				
			||||||
 | 
					    parser.add_argument("-o", "--origin", help="Set origin")
 | 
				
			||||||
 | 
					    parser.add_argument(
 | 
				
			||||||
 | 
					        "--eof-wait",
 | 
				
			||||||
 | 
					        default=0,
 | 
				
			||||||
 | 
					        type=int,
 | 
				
			||||||
 | 
					        help="wait time(second) after 'EOF' received.",
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    parser.add_argument("-t", "--text", help="Send initial text")
 | 
				
			||||||
 | 
					    parser.add_argument(
 | 
				
			||||||
 | 
					        "--timings", action="store_true", help="Print timings in seconds"
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    parser.add_argument("--headers", help="Set custom headers. Use ',' as separator")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return parser.parse_args()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class RawInput:
 | 
				
			||||||
 | 
					    def raw_input(self, prompt: str = "") -> str:
 | 
				
			||||||
 | 
					        line = input(prompt)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if ENCODING and ENCODING != "utf-8" and not isinstance(line, str):
 | 
				
			||||||
 | 
					            line = line.decode(ENCODING).encode("utf-8")
 | 
				
			||||||
 | 
					        elif isinstance(line, str):
 | 
				
			||||||
 | 
					            line = line.encode("utf-8")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return line
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class InteractiveConsole(RawInput, code.InteractiveConsole):
 | 
				
			||||||
 | 
					    def write(self, data: str) -> None:
 | 
				
			||||||
 | 
					        sys.stdout.write("\033[2K\033[E")
 | 
				
			||||||
 | 
					        # sys.stdout.write("\n")
 | 
				
			||||||
 | 
					        sys.stdout.write("\033[34m< " + data + "\033[39m")
 | 
				
			||||||
 | 
					        sys.stdout.write("\n> ")
 | 
				
			||||||
 | 
					        sys.stdout.flush()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def read(self) -> str:
 | 
				
			||||||
 | 
					        return self.raw_input("> ")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class NonInteractive(RawInput):
 | 
				
			||||||
 | 
					    def write(self, data: str) -> None:
 | 
				
			||||||
 | 
					        sys.stdout.write(data)
 | 
				
			||||||
 | 
					        sys.stdout.write("\n")
 | 
				
			||||||
 | 
					        sys.stdout.flush()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def read(self) -> str:
 | 
				
			||||||
 | 
					        return self.raw_input("")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def main() -> None:
 | 
				
			||||||
 | 
					    start_time = time.time()
 | 
				
			||||||
 | 
					    args = parse_args()
 | 
				
			||||||
 | 
					    if args.verbose > 1:
 | 
				
			||||||
 | 
					        websocket.enableTrace(True)
 | 
				
			||||||
 | 
					    options = {}
 | 
				
			||||||
 | 
					    if args.proxy:
 | 
				
			||||||
 | 
					        p = urlparse(args.proxy)
 | 
				
			||||||
 | 
					        options["http_proxy_host"] = p.hostname
 | 
				
			||||||
 | 
					        options["http_proxy_port"] = p.port
 | 
				
			||||||
 | 
					    if args.origin:
 | 
				
			||||||
 | 
					        options["origin"] = args.origin
 | 
				
			||||||
 | 
					    if args.subprotocols:
 | 
				
			||||||
 | 
					        options["subprotocols"] = args.subprotocols
 | 
				
			||||||
 | 
					    opts = {}
 | 
				
			||||||
 | 
					    if args.nocert:
 | 
				
			||||||
 | 
					        opts = {"cert_reqs": ssl.CERT_NONE, "check_hostname": False}
 | 
				
			||||||
 | 
					    if args.headers:
 | 
				
			||||||
 | 
					        options["header"] = list(map(str.strip, args.headers.split(",")))
 | 
				
			||||||
 | 
					    ws = websocket.create_connection(args.url, sslopt=opts, **options)
 | 
				
			||||||
 | 
					    if args.raw:
 | 
				
			||||||
 | 
					        console = NonInteractive()
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        console = InteractiveConsole()
 | 
				
			||||||
 | 
					        print("Press Ctrl+C to quit")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def recv() -> tuple:
 | 
				
			||||||
 | 
					        try:
 | 
				
			||||||
 | 
					            frame = ws.recv_frame()
 | 
				
			||||||
 | 
					        except websocket.WebSocketException:
 | 
				
			||||||
 | 
					            return websocket.ABNF.OPCODE_CLOSE, ""
 | 
				
			||||||
 | 
					        if not frame:
 | 
				
			||||||
 | 
					            raise websocket.WebSocketException(f"Not a valid frame {frame}")
 | 
				
			||||||
 | 
					        elif frame.opcode in OPCODE_DATA:
 | 
				
			||||||
 | 
					            return frame.opcode, frame.data
 | 
				
			||||||
 | 
					        elif frame.opcode == websocket.ABNF.OPCODE_CLOSE:
 | 
				
			||||||
 | 
					            ws.send_close()
 | 
				
			||||||
 | 
					            return frame.opcode, ""
 | 
				
			||||||
 | 
					        elif frame.opcode == websocket.ABNF.OPCODE_PING:
 | 
				
			||||||
 | 
					            ws.pong(frame.data)
 | 
				
			||||||
 | 
					            return frame.opcode, frame.data
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return frame.opcode, frame.data
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def recv_ws() -> None:
 | 
				
			||||||
 | 
					        while True:
 | 
				
			||||||
 | 
					            opcode, data = recv()
 | 
				
			||||||
 | 
					            msg = None
 | 
				
			||||||
 | 
					            if opcode == websocket.ABNF.OPCODE_TEXT and isinstance(data, bytes):
 | 
				
			||||||
 | 
					                data = str(data, "utf-8")
 | 
				
			||||||
 | 
					            if (
 | 
				
			||||||
 | 
					                isinstance(data, bytes) and len(data) > 2 and data[:2] == b"\037\213"
 | 
				
			||||||
 | 
					            ):  # gzip magick
 | 
				
			||||||
 | 
					                try:
 | 
				
			||||||
 | 
					                    data = "[gzip] " + str(gzip.decompress(data), "utf-8")
 | 
				
			||||||
 | 
					                except:
 | 
				
			||||||
 | 
					                    pass
 | 
				
			||||||
 | 
					            elif isinstance(data, bytes):
 | 
				
			||||||
 | 
					                try:
 | 
				
			||||||
 | 
					                    data = "[zlib] " + str(
 | 
				
			||||||
 | 
					                        zlib.decompress(data, -zlib.MAX_WBITS), "utf-8"
 | 
				
			||||||
 | 
					                    )
 | 
				
			||||||
 | 
					                except:
 | 
				
			||||||
 | 
					                    pass
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            if isinstance(data, bytes):
 | 
				
			||||||
 | 
					                data = repr(data)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            if args.verbose:
 | 
				
			||||||
 | 
					                msg = f"{websocket.ABNF.OPCODE_MAP.get(opcode)}: {data}"
 | 
				
			||||||
 | 
					            else:
 | 
				
			||||||
 | 
					                msg = data
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            if msg is not None:
 | 
				
			||||||
 | 
					                if args.timings:
 | 
				
			||||||
 | 
					                    console.write(f"{time.time() - start_time}: {msg}")
 | 
				
			||||||
 | 
					                else:
 | 
				
			||||||
 | 
					                    console.write(msg)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            if opcode == websocket.ABNF.OPCODE_CLOSE:
 | 
				
			||||||
 | 
					                break
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    thread = threading.Thread(target=recv_ws)
 | 
				
			||||||
 | 
					    thread.daemon = True
 | 
				
			||||||
 | 
					    thread.start()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if args.text:
 | 
				
			||||||
 | 
					        ws.send(args.text)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    while True:
 | 
				
			||||||
 | 
					        try:
 | 
				
			||||||
 | 
					            message = console.read()
 | 
				
			||||||
 | 
					            ws.send(message)
 | 
				
			||||||
 | 
					        except KeyboardInterrupt:
 | 
				
			||||||
 | 
					            return
 | 
				
			||||||
 | 
					        except EOFError:
 | 
				
			||||||
 | 
					            time.sleep(args.eof_wait)
 | 
				
			||||||
 | 
					            return
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					if __name__ == "__main__":
 | 
				
			||||||
 | 
					    try:
 | 
				
			||||||
 | 
					        main()
 | 
				
			||||||
 | 
					    except Exception as e:
 | 
				
			||||||
 | 
					        print(e)
 | 
				
			||||||
							
								
								
									
										0
									
								
								src/libs/websockets/py.typed → src/libs/websocket/py.typed
									
									
									
									
									
										
										
										Executable file → Normal file
									
								
							
							
						
						
									
										0
									
								
								src/libs/websockets/py.typed → src/libs/websocket/py.typed
									
									
									
									
									
										
										
										Executable file → Normal file
									
								
							
							
								
								
									
										0
									
								
								src/libs/websockets/asyncio/__init__.py → src/libs/websocket/tests/__init__.py
									
									
									
									
									
										
										
										Executable file → Normal file
									
								
							
							
						
						
									
										0
									
								
								src/libs/websockets/asyncio/__init__.py → src/libs/websocket/tests/__init__.py
									
									
									
									
									
										
										
										Executable file → Normal file
									
								
							
							
								
								
									
										6
									
								
								src/libs/websocket/tests/data/header01.txt
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										6
									
								
								src/libs/websocket/tests/data/header01.txt
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,6 @@
 | 
				
			|||||||
 | 
					HTTP/1.1 101 WebSocket Protocol Handshake
 | 
				
			||||||
 | 
					Connection: Upgrade 
 | 
				
			||||||
 | 
					Upgrade: WebSocket
 | 
				
			||||||
 | 
					Sec-WebSocket-Accept: Kxep+hNu9n51529fGidYu7a3wO0=
 | 
				
			||||||
 | 
					some_header: something
 | 
				
			||||||
 | 
					
 | 
				
			||||||
							
								
								
									
										6
									
								
								src/libs/websocket/tests/data/header02.txt
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										6
									
								
								src/libs/websocket/tests/data/header02.txt
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,6 @@
 | 
				
			|||||||
 | 
					HTTP/1.1 101 WebSocket Protocol Handshake
 | 
				
			||||||
 | 
					Connection: Upgrade
 | 
				
			||||||
 | 
					Upgrade WebSocket
 | 
				
			||||||
 | 
					Sec-WebSocket-Accept: Kxep+hNu9n51529fGidYu7a3wO0=
 | 
				
			||||||
 | 
					some_header: something
 | 
				
			||||||
 | 
					
 | 
				
			||||||
							
								
								
									
										8
									
								
								src/libs/websocket/tests/data/header03.txt
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										8
									
								
								src/libs/websocket/tests/data/header03.txt
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,8 @@
 | 
				
			|||||||
 | 
					HTTP/1.1 101 WebSocket Protocol Handshake
 | 
				
			||||||
 | 
					Connection: Upgrade, Keep-Alive
 | 
				
			||||||
 | 
					Upgrade: WebSocket
 | 
				
			||||||
 | 
					Sec-WebSocket-Accept: Kxep+hNu9n51529fGidYu7a3wO0=
 | 
				
			||||||
 | 
					Set-Cookie: Token=ABCDE
 | 
				
			||||||
 | 
					Set-Cookie: Token=FGHIJ
 | 
				
			||||||
 | 
					some_header: something
 | 
				
			||||||
 | 
					
 | 
				
			||||||
							
								
								
									
										23
									
								
								src/libs/websocket/tests/echo-server.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										23
									
								
								src/libs/websocket/tests/echo-server.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,23 @@
 | 
				
			|||||||
 | 
					#!/usr/bin/env python
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# From https://github.com/aaugustin/websockets/blob/main/example/echo.py
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import asyncio
 | 
				
			||||||
 | 
					import os
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import websockets
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					LOCAL_WS_SERVER_PORT = int(os.environ.get("LOCAL_WS_SERVER_PORT", "8765"))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					async def echo(websocket):
 | 
				
			||||||
 | 
					    async for message in websocket:
 | 
				
			||||||
 | 
					        await websocket.send(message)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					async def main():
 | 
				
			||||||
 | 
					    async with websockets.serve(echo, "localhost", LOCAL_WS_SERVER_PORT):
 | 
				
			||||||
 | 
					        await asyncio.Future()  # run forever
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					asyncio.run(main())
 | 
				
			||||||
							
								
								
									
										125
									
								
								src/libs/websocket/tests/test_abnf.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										125
									
								
								src/libs/websocket/tests/test_abnf.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,125 @@
 | 
				
			|||||||
 | 
					# -*- coding: utf-8 -*-
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					import unittest
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from websocket._abnf import ABNF, frame_buffer
 | 
				
			||||||
 | 
					from websocket._exceptions import WebSocketProtocolException
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					"""
 | 
				
			||||||
 | 
					test_abnf.py
 | 
				
			||||||
 | 
					websocket - WebSocket client library for Python
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Copyright 2024 engn33r
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					You may obtain a copy of the License at
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					limitations under the License.
 | 
				
			||||||
 | 
					"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class ABNFTest(unittest.TestCase):
 | 
				
			||||||
 | 
					    def test_init(self):
 | 
				
			||||||
 | 
					        a = ABNF(0, 0, 0, 0, opcode=ABNF.OPCODE_PING)
 | 
				
			||||||
 | 
					        self.assertEqual(a.fin, 0)
 | 
				
			||||||
 | 
					        self.assertEqual(a.rsv1, 0)
 | 
				
			||||||
 | 
					        self.assertEqual(a.rsv2, 0)
 | 
				
			||||||
 | 
					        self.assertEqual(a.rsv3, 0)
 | 
				
			||||||
 | 
					        self.assertEqual(a.opcode, 9)
 | 
				
			||||||
 | 
					        self.assertEqual(a.data, "")
 | 
				
			||||||
 | 
					        a_bad = ABNF(0, 1, 0, 0, opcode=77)
 | 
				
			||||||
 | 
					        self.assertEqual(a_bad.rsv1, 1)
 | 
				
			||||||
 | 
					        self.assertEqual(a_bad.opcode, 77)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_validate(self):
 | 
				
			||||||
 | 
					        a_invalid_ping = ABNF(0, 0, 0, 0, opcode=ABNF.OPCODE_PING)
 | 
				
			||||||
 | 
					        self.assertRaises(
 | 
				
			||||||
 | 
					            WebSocketProtocolException,
 | 
				
			||||||
 | 
					            a_invalid_ping.validate,
 | 
				
			||||||
 | 
					            skip_utf8_validation=False,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        a_bad_rsv_value = ABNF(0, 1, 0, 0, opcode=ABNF.OPCODE_TEXT)
 | 
				
			||||||
 | 
					        self.assertRaises(
 | 
				
			||||||
 | 
					            WebSocketProtocolException,
 | 
				
			||||||
 | 
					            a_bad_rsv_value.validate,
 | 
				
			||||||
 | 
					            skip_utf8_validation=False,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        a_bad_opcode = ABNF(0, 0, 0, 0, opcode=77)
 | 
				
			||||||
 | 
					        self.assertRaises(
 | 
				
			||||||
 | 
					            WebSocketProtocolException,
 | 
				
			||||||
 | 
					            a_bad_opcode.validate,
 | 
				
			||||||
 | 
					            skip_utf8_validation=False,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        a_bad_close_frame = ABNF(0, 0, 0, 0, opcode=ABNF.OPCODE_CLOSE, data=b"\x01")
 | 
				
			||||||
 | 
					        self.assertRaises(
 | 
				
			||||||
 | 
					            WebSocketProtocolException,
 | 
				
			||||||
 | 
					            a_bad_close_frame.validate,
 | 
				
			||||||
 | 
					            skip_utf8_validation=False,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        a_bad_close_frame_2 = ABNF(
 | 
				
			||||||
 | 
					            0, 0, 0, 0, opcode=ABNF.OPCODE_CLOSE, data=b"\x01\x8a\xaa\xff\xdd"
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.assertRaises(
 | 
				
			||||||
 | 
					            WebSocketProtocolException,
 | 
				
			||||||
 | 
					            a_bad_close_frame_2.validate,
 | 
				
			||||||
 | 
					            skip_utf8_validation=False,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        a_bad_close_frame_3 = ABNF(
 | 
				
			||||||
 | 
					            0, 0, 0, 0, opcode=ABNF.OPCODE_CLOSE, data=b"\x03\xe7"
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.assertRaises(
 | 
				
			||||||
 | 
					            WebSocketProtocolException,
 | 
				
			||||||
 | 
					            a_bad_close_frame_3.validate,
 | 
				
			||||||
 | 
					            skip_utf8_validation=True,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_mask(self):
 | 
				
			||||||
 | 
					        abnf_none_data = ABNF(
 | 
				
			||||||
 | 
					            0, 0, 0, 0, opcode=ABNF.OPCODE_PING, mask_value=1, data=None
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        bytes_val = b"aaaa"
 | 
				
			||||||
 | 
					        self.assertEqual(abnf_none_data._get_masked(bytes_val), bytes_val)
 | 
				
			||||||
 | 
					        abnf_str_data = ABNF(
 | 
				
			||||||
 | 
					            0, 0, 0, 0, opcode=ABNF.OPCODE_PING, mask_value=1, data="a"
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.assertEqual(abnf_str_data._get_masked(bytes_val), b"aaaa\x00")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_format(self):
 | 
				
			||||||
 | 
					        abnf_bad_rsv_bits = ABNF(2, 0, 0, 0, opcode=ABNF.OPCODE_TEXT)
 | 
				
			||||||
 | 
					        self.assertRaises(ValueError, abnf_bad_rsv_bits.format)
 | 
				
			||||||
 | 
					        abnf_bad_opcode = ABNF(0, 0, 0, 0, opcode=5)
 | 
				
			||||||
 | 
					        self.assertRaises(ValueError, abnf_bad_opcode.format)
 | 
				
			||||||
 | 
					        abnf_length_10 = ABNF(0, 0, 0, 0, opcode=ABNF.OPCODE_TEXT, data="abcdefghij")
 | 
				
			||||||
 | 
					        self.assertEqual(b"\x01", abnf_length_10.format()[0].to_bytes(1, "big"))
 | 
				
			||||||
 | 
					        self.assertEqual(b"\x8a", abnf_length_10.format()[1].to_bytes(1, "big"))
 | 
				
			||||||
 | 
					        self.assertEqual("fin=0 opcode=1 data=abcdefghij", abnf_length_10.__str__())
 | 
				
			||||||
 | 
					        abnf_length_20 = ABNF(
 | 
				
			||||||
 | 
					            0, 0, 0, 0, opcode=ABNF.OPCODE_BINARY, data="abcdefghijabcdefghij"
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.assertEqual(b"\x02", abnf_length_20.format()[0].to_bytes(1, "big"))
 | 
				
			||||||
 | 
					        self.assertEqual(b"\x94", abnf_length_20.format()[1].to_bytes(1, "big"))
 | 
				
			||||||
 | 
					        abnf_no_mask = ABNF(
 | 
				
			||||||
 | 
					            0, 0, 0, 0, opcode=ABNF.OPCODE_TEXT, mask_value=0, data=b"\x01\x8a\xcc"
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.assertEqual(b"\x01\x03\x01\x8a\xcc", abnf_no_mask.format())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_frame_buffer(self):
 | 
				
			||||||
 | 
					        fb = frame_buffer(0, True)
 | 
				
			||||||
 | 
					        self.assertEqual(fb.recv, 0)
 | 
				
			||||||
 | 
					        self.assertEqual(fb.skip_utf8_validation, True)
 | 
				
			||||||
 | 
					        fb.clear
 | 
				
			||||||
 | 
					        self.assertEqual(fb.header, None)
 | 
				
			||||||
 | 
					        self.assertEqual(fb.length, None)
 | 
				
			||||||
 | 
					        self.assertEqual(fb.mask_value, None)
 | 
				
			||||||
 | 
					        self.assertEqual(fb.has_mask(), False)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					if __name__ == "__main__":
 | 
				
			||||||
 | 
					    unittest.main()
 | 
				
			||||||
							
								
								
									
										352
									
								
								src/libs/websocket/tests/test_app.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										352
									
								
								src/libs/websocket/tests/test_app.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,352 @@
 | 
				
			|||||||
 | 
					# -*- coding: utf-8 -*-
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					import os
 | 
				
			||||||
 | 
					import os.path
 | 
				
			||||||
 | 
					import ssl
 | 
				
			||||||
 | 
					import threading
 | 
				
			||||||
 | 
					import unittest
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import websocket as ws
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					"""
 | 
				
			||||||
 | 
					test_app.py
 | 
				
			||||||
 | 
					websocket - WebSocket client library for Python
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Copyright 2024 engn33r
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					You may obtain a copy of the License at
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					limitations under the License.
 | 
				
			||||||
 | 
					"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# Skip test to access the internet unless TEST_WITH_INTERNET == 1
 | 
				
			||||||
 | 
					TEST_WITH_INTERNET = os.environ.get("TEST_WITH_INTERNET", "0") == "1"
 | 
				
			||||||
 | 
					# Skip tests relying on local websockets server unless LOCAL_WS_SERVER_PORT != -1
 | 
				
			||||||
 | 
					LOCAL_WS_SERVER_PORT = os.environ.get("LOCAL_WS_SERVER_PORT", "-1")
 | 
				
			||||||
 | 
					TEST_WITH_LOCAL_SERVER = LOCAL_WS_SERVER_PORT != "-1"
 | 
				
			||||||
 | 
					TRACEABLE = True
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class WebSocketAppTest(unittest.TestCase):
 | 
				
			||||||
 | 
					    class NotSetYet:
 | 
				
			||||||
 | 
					        """A marker class for signalling that a value hasn't been set yet."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def setUp(self):
 | 
				
			||||||
 | 
					        ws.enableTrace(TRACEABLE)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        WebSocketAppTest.keep_running_open = WebSocketAppTest.NotSetYet()
 | 
				
			||||||
 | 
					        WebSocketAppTest.keep_running_close = WebSocketAppTest.NotSetYet()
 | 
				
			||||||
 | 
					        WebSocketAppTest.get_mask_key_id = WebSocketAppTest.NotSetYet()
 | 
				
			||||||
 | 
					        WebSocketAppTest.on_error_data = WebSocketAppTest.NotSetYet()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def tearDown(self):
 | 
				
			||||||
 | 
					        WebSocketAppTest.keep_running_open = WebSocketAppTest.NotSetYet()
 | 
				
			||||||
 | 
					        WebSocketAppTest.keep_running_close = WebSocketAppTest.NotSetYet()
 | 
				
			||||||
 | 
					        WebSocketAppTest.get_mask_key_id = WebSocketAppTest.NotSetYet()
 | 
				
			||||||
 | 
					        WebSocketAppTest.on_error_data = WebSocketAppTest.NotSetYet()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def close(self):
 | 
				
			||||||
 | 
					        pass
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @unittest.skipUnless(
 | 
				
			||||||
 | 
					        TEST_WITH_LOCAL_SERVER, "Tests using local websocket server are disabled"
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    def test_keep_running(self):
 | 
				
			||||||
 | 
					        """A WebSocketApp should keep running as long as its self.keep_running
 | 
				
			||||||
 | 
					        is not False (in the boolean context).
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        def on_open(self, *args, **kwargs):
 | 
				
			||||||
 | 
					            """Set the keep_running flag for later inspection and immediately
 | 
				
			||||||
 | 
					            close the connection.
 | 
				
			||||||
 | 
					            """
 | 
				
			||||||
 | 
					            self.send("hello!")
 | 
				
			||||||
 | 
					            WebSocketAppTest.keep_running_open = self.keep_running
 | 
				
			||||||
 | 
					            self.keep_running = False
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        def on_message(_, message):
 | 
				
			||||||
 | 
					            print(message)
 | 
				
			||||||
 | 
					            self.close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        def on_close(self, *args, **kwargs):
 | 
				
			||||||
 | 
					            """Set the keep_running flag for the test to use."""
 | 
				
			||||||
 | 
					            WebSocketAppTest.keep_running_close = self.keep_running
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        app = ws.WebSocketApp(
 | 
				
			||||||
 | 
					            f"ws://127.0.0.1:{LOCAL_WS_SERVER_PORT}",
 | 
				
			||||||
 | 
					            on_open=on_open,
 | 
				
			||||||
 | 
					            on_close=on_close,
 | 
				
			||||||
 | 
					            on_message=on_message,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        app.run_forever()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    #    @unittest.skipUnless(TEST_WITH_LOCAL_SERVER, "Tests using local websocket server are disabled")
 | 
				
			||||||
 | 
					    @unittest.skipUnless(False, "Test disabled for now (requires rel)")
 | 
				
			||||||
 | 
					    def test_run_forever_dispatcher(self):
 | 
				
			||||||
 | 
					        """A WebSocketApp should keep running as long as its self.keep_running
 | 
				
			||||||
 | 
					        is not False (in the boolean context).
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        def on_open(self, *args, **kwargs):
 | 
				
			||||||
 | 
					            """Send a message, receive, and send one more"""
 | 
				
			||||||
 | 
					            self.send("hello!")
 | 
				
			||||||
 | 
					            self.recv()
 | 
				
			||||||
 | 
					            self.send("goodbye!")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        def on_message(_, message):
 | 
				
			||||||
 | 
					            print(message)
 | 
				
			||||||
 | 
					            self.close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        app = ws.WebSocketApp(
 | 
				
			||||||
 | 
					            f"ws://127.0.0.1:{LOCAL_WS_SERVER_PORT}",
 | 
				
			||||||
 | 
					            on_open=on_open,
 | 
				
			||||||
 | 
					            on_message=on_message,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        app.run_forever(dispatcher="Dispatcher")  # doesn't work
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    #        app.run_forever(dispatcher=rel)          # would work
 | 
				
			||||||
 | 
					    #        rel.dispatch()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @unittest.skipUnless(
 | 
				
			||||||
 | 
					        TEST_WITH_LOCAL_SERVER, "Tests using local websocket server are disabled"
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    def test_run_forever_teardown_clean_exit(self):
 | 
				
			||||||
 | 
					        """The WebSocketApp.run_forever() method should return `False` when the application ends gracefully."""
 | 
				
			||||||
 | 
					        app = ws.WebSocketApp(f"ws://127.0.0.1:{LOCAL_WS_SERVER_PORT}")
 | 
				
			||||||
 | 
					        threading.Timer(interval=0.2, function=app.close).start()
 | 
				
			||||||
 | 
					        teardown = app.run_forever()
 | 
				
			||||||
 | 
					        self.assertEqual(teardown, False)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
 | 
				
			||||||
 | 
					    def test_sock_mask_key(self):
 | 
				
			||||||
 | 
					        """A WebSocketApp should forward the received mask_key function down
 | 
				
			||||||
 | 
					        to the actual socket.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        def my_mask_key_func():
 | 
				
			||||||
 | 
					            return "\x00\x00\x00\x00"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        app = ws.WebSocketApp(
 | 
				
			||||||
 | 
					            "wss://api-pub.bitfinex.com/ws/1", get_mask_key=my_mask_key_func
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # if numpy is installed, this assertion fail
 | 
				
			||||||
 | 
					        # Note: We can't use 'is' for comparing the functions directly, need to use 'id'.
 | 
				
			||||||
 | 
					        self.assertEqual(id(app.get_mask_key), id(my_mask_key_func))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
 | 
				
			||||||
 | 
					    def test_invalid_ping_interval_ping_timeout(self):
 | 
				
			||||||
 | 
					        """Test exception handling if ping_interval < ping_timeout"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        def on_ping(app, _):
 | 
				
			||||||
 | 
					            print("Got a ping!")
 | 
				
			||||||
 | 
					            app.close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        def on_pong(app, _):
 | 
				
			||||||
 | 
					            print("Got a pong! No need to respond")
 | 
				
			||||||
 | 
					            app.close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        app = ws.WebSocketApp(
 | 
				
			||||||
 | 
					            "wss://api-pub.bitfinex.com/ws/1", on_ping=on_ping, on_pong=on_pong
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.assertRaises(
 | 
				
			||||||
 | 
					            ws.WebSocketException,
 | 
				
			||||||
 | 
					            app.run_forever,
 | 
				
			||||||
 | 
					            ping_interval=1,
 | 
				
			||||||
 | 
					            ping_timeout=2,
 | 
				
			||||||
 | 
					            sslopt={"cert_reqs": ssl.CERT_NONE},
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
 | 
				
			||||||
 | 
					    def test_ping_interval(self):
 | 
				
			||||||
 | 
					        """Test WebSocketApp proper ping functionality"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        def on_ping(app, _):
 | 
				
			||||||
 | 
					            print("Got a ping!")
 | 
				
			||||||
 | 
					            app.close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        def on_pong(app, _):
 | 
				
			||||||
 | 
					            print("Got a pong! No need to respond")
 | 
				
			||||||
 | 
					            app.close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        app = ws.WebSocketApp(
 | 
				
			||||||
 | 
					            "wss://api-pub.bitfinex.com/ws/1", on_ping=on_ping, on_pong=on_pong
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        app.run_forever(
 | 
				
			||||||
 | 
					            ping_interval=2, ping_timeout=1, sslopt={"cert_reqs": ssl.CERT_NONE}
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
 | 
				
			||||||
 | 
					    def test_opcode_close(self):
 | 
				
			||||||
 | 
					        """Test WebSocketApp close opcode"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        app = ws.WebSocketApp("wss://tsock.us1.twilio.com/v3/wsconnect")
 | 
				
			||||||
 | 
					        app.run_forever(ping_interval=2, ping_timeout=1, ping_payload="Ping payload")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # This is commented out because the URL no longer responds in the expected way
 | 
				
			||||||
 | 
					    # @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
 | 
				
			||||||
 | 
					    # def testOpcodeBinary(self):
 | 
				
			||||||
 | 
					    #     """ Test WebSocketApp binary opcode
 | 
				
			||||||
 | 
					    #     """
 | 
				
			||||||
 | 
					    #     app = ws.WebSocketApp('wss://streaming.vn.teslamotors.com/streaming/')
 | 
				
			||||||
 | 
					    #     app.run_forever(ping_interval=2, ping_timeout=1, ping_payload="Ping payload")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
 | 
				
			||||||
 | 
					    def test_bad_ping_interval(self):
 | 
				
			||||||
 | 
					        """A WebSocketApp handling of negative ping_interval"""
 | 
				
			||||||
 | 
					        app = ws.WebSocketApp("wss://api-pub.bitfinex.com/ws/1")
 | 
				
			||||||
 | 
					        self.assertRaises(
 | 
				
			||||||
 | 
					            ws.WebSocketException,
 | 
				
			||||||
 | 
					            app.run_forever,
 | 
				
			||||||
 | 
					            ping_interval=-5,
 | 
				
			||||||
 | 
					            sslopt={"cert_reqs": ssl.CERT_NONE},
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
 | 
				
			||||||
 | 
					    def test_bad_ping_timeout(self):
 | 
				
			||||||
 | 
					        """A WebSocketApp handling of negative ping_timeout"""
 | 
				
			||||||
 | 
					        app = ws.WebSocketApp("wss://api-pub.bitfinex.com/ws/1")
 | 
				
			||||||
 | 
					        self.assertRaises(
 | 
				
			||||||
 | 
					            ws.WebSocketException,
 | 
				
			||||||
 | 
					            app.run_forever,
 | 
				
			||||||
 | 
					            ping_timeout=-3,
 | 
				
			||||||
 | 
					            sslopt={"cert_reqs": ssl.CERT_NONE},
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
 | 
				
			||||||
 | 
					    def test_close_status_code(self):
 | 
				
			||||||
 | 
					        """Test extraction of close frame status code and close reason in WebSocketApp"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        def on_close(wsapp, close_status_code, close_msg):
 | 
				
			||||||
 | 
					            print("on_close reached")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        app = ws.WebSocketApp(
 | 
				
			||||||
 | 
					            "wss://tsock.us1.twilio.com/v3/wsconnect", on_close=on_close
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        closeframe = ws.ABNF(
 | 
				
			||||||
 | 
					            opcode=ws.ABNF.OPCODE_CLOSE, data=b"\x03\xe8no-init-from-client"
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.assertEqual([1000, "no-init-from-client"], app._get_close_args(closeframe))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        closeframe = ws.ABNF(opcode=ws.ABNF.OPCODE_CLOSE, data=b"")
 | 
				
			||||||
 | 
					        self.assertEqual([None, None], app._get_close_args(closeframe))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        app2 = ws.WebSocketApp("wss://tsock.us1.twilio.com/v3/wsconnect")
 | 
				
			||||||
 | 
					        closeframe = ws.ABNF(opcode=ws.ABNF.OPCODE_CLOSE, data=b"")
 | 
				
			||||||
 | 
					        self.assertEqual([None, None], app2._get_close_args(closeframe))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.assertRaises(
 | 
				
			||||||
 | 
					            ws.WebSocketConnectionClosedException,
 | 
				
			||||||
 | 
					            app.send,
 | 
				
			||||||
 | 
					            data="test if connection is closed",
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @unittest.skipUnless(
 | 
				
			||||||
 | 
					        TEST_WITH_LOCAL_SERVER, "Tests using local websocket server are disabled"
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    def test_callback_function_exception(self):
 | 
				
			||||||
 | 
					        """Test callback function exception handling"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        exc = None
 | 
				
			||||||
 | 
					        passed_app = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        def on_open(app):
 | 
				
			||||||
 | 
					            raise RuntimeError("Callback failed")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        def on_error(app, err):
 | 
				
			||||||
 | 
					            nonlocal passed_app
 | 
				
			||||||
 | 
					            passed_app = app
 | 
				
			||||||
 | 
					            nonlocal exc
 | 
				
			||||||
 | 
					            exc = err
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        def on_pong(app, _):
 | 
				
			||||||
 | 
					            app.close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        app = ws.WebSocketApp(
 | 
				
			||||||
 | 
					            f"ws://127.0.0.1:{LOCAL_WS_SERVER_PORT}",
 | 
				
			||||||
 | 
					            on_open=on_open,
 | 
				
			||||||
 | 
					            on_error=on_error,
 | 
				
			||||||
 | 
					            on_pong=on_pong,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        app.run_forever(ping_interval=2, ping_timeout=1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.assertEqual(passed_app, app)
 | 
				
			||||||
 | 
					        self.assertIsInstance(exc, RuntimeError)
 | 
				
			||||||
 | 
					        self.assertEqual(str(exc), "Callback failed")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @unittest.skipUnless(
 | 
				
			||||||
 | 
					        TEST_WITH_LOCAL_SERVER, "Tests using local websocket server are disabled"
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    def test_callback_method_exception(self):
 | 
				
			||||||
 | 
					        """Test callback method exception handling"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        class Callbacks:
 | 
				
			||||||
 | 
					            def __init__(self):
 | 
				
			||||||
 | 
					                self.exc = None
 | 
				
			||||||
 | 
					                self.passed_app = None
 | 
				
			||||||
 | 
					                self.app = ws.WebSocketApp(
 | 
				
			||||||
 | 
					                    f"ws://127.0.0.1:{LOCAL_WS_SERVER_PORT}",
 | 
				
			||||||
 | 
					                    on_open=self.on_open,
 | 
				
			||||||
 | 
					                    on_error=self.on_error,
 | 
				
			||||||
 | 
					                    on_pong=self.on_pong,
 | 
				
			||||||
 | 
					                )
 | 
				
			||||||
 | 
					                self.app.run_forever(ping_interval=2, ping_timeout=1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            def on_open(self, _):
 | 
				
			||||||
 | 
					                raise RuntimeError("Callback failed")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            def on_error(self, app, err):
 | 
				
			||||||
 | 
					                self.passed_app = app
 | 
				
			||||||
 | 
					                self.exc = err
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            def on_pong(self, app, _):
 | 
				
			||||||
 | 
					                app.close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        callbacks = Callbacks()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.assertEqual(callbacks.passed_app, callbacks.app)
 | 
				
			||||||
 | 
					        self.assertIsInstance(callbacks.exc, RuntimeError)
 | 
				
			||||||
 | 
					        self.assertEqual(str(callbacks.exc), "Callback failed")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @unittest.skipUnless(
 | 
				
			||||||
 | 
					        TEST_WITH_LOCAL_SERVER, "Tests using local websocket server are disabled"
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    def test_reconnect(self):
 | 
				
			||||||
 | 
					        """Test reconnect"""
 | 
				
			||||||
 | 
					        pong_count = 0
 | 
				
			||||||
 | 
					        exc = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        def on_error(_, err):
 | 
				
			||||||
 | 
					            nonlocal exc
 | 
				
			||||||
 | 
					            exc = err
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        def on_pong(app, _):
 | 
				
			||||||
 | 
					            nonlocal pong_count
 | 
				
			||||||
 | 
					            pong_count += 1
 | 
				
			||||||
 | 
					            if pong_count == 1:
 | 
				
			||||||
 | 
					                # First pong, shutdown socket, enforce read error
 | 
				
			||||||
 | 
					                app.sock.shutdown()
 | 
				
			||||||
 | 
					            if pong_count >= 2:
 | 
				
			||||||
 | 
					                # Got second pong after reconnect
 | 
				
			||||||
 | 
					                app.close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        app = ws.WebSocketApp(
 | 
				
			||||||
 | 
					            f"ws://127.0.0.1:{LOCAL_WS_SERVER_PORT}", on_pong=on_pong, on_error=on_error
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        app.run_forever(ping_interval=2, ping_timeout=1, reconnect=3)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.assertEqual(pong_count, 2)
 | 
				
			||||||
 | 
					        self.assertIsInstance(exc, ws.WebSocketTimeoutException)
 | 
				
			||||||
 | 
					        self.assertEqual(str(exc), "ping/pong timed out")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					if __name__ == "__main__":
 | 
				
			||||||
 | 
					    unittest.main()
 | 
				
			||||||
							
								
								
									
										123
									
								
								src/libs/websocket/tests/test_cookiejar.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										123
									
								
								src/libs/websocket/tests/test_cookiejar.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,123 @@
 | 
				
			|||||||
 | 
					import unittest
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from websocket._cookiejar import SimpleCookieJar
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					"""
 | 
				
			||||||
 | 
					test_cookiejar.py
 | 
				
			||||||
 | 
					websocket - WebSocket client library for Python
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Copyright 2024 engn33r
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					You may obtain a copy of the License at
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					limitations under the License.
 | 
				
			||||||
 | 
					"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class CookieJarTest(unittest.TestCase):
 | 
				
			||||||
 | 
					    def test_add(self):
 | 
				
			||||||
 | 
					        cookie_jar = SimpleCookieJar()
 | 
				
			||||||
 | 
					        cookie_jar.add("")
 | 
				
			||||||
 | 
					        self.assertFalse(
 | 
				
			||||||
 | 
					            cookie_jar.jar, "Cookie with no domain should not be added to the jar"
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        cookie_jar = SimpleCookieJar()
 | 
				
			||||||
 | 
					        cookie_jar.add("a=b")
 | 
				
			||||||
 | 
					        self.assertFalse(
 | 
				
			||||||
 | 
					            cookie_jar.jar, "Cookie with no domain should not be added to the jar"
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        cookie_jar = SimpleCookieJar()
 | 
				
			||||||
 | 
					        cookie_jar.add("a=b; domain=.abc")
 | 
				
			||||||
 | 
					        self.assertTrue(".abc" in cookie_jar.jar)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        cookie_jar = SimpleCookieJar()
 | 
				
			||||||
 | 
					        cookie_jar.add("a=b; domain=abc")
 | 
				
			||||||
 | 
					        self.assertTrue(".abc" in cookie_jar.jar)
 | 
				
			||||||
 | 
					        self.assertTrue("abc" not in cookie_jar.jar)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        cookie_jar = SimpleCookieJar()
 | 
				
			||||||
 | 
					        cookie_jar.add("a=b; c=d; domain=abc")
 | 
				
			||||||
 | 
					        self.assertEqual(cookie_jar.get("abc"), "a=b; c=d")
 | 
				
			||||||
 | 
					        self.assertEqual(cookie_jar.get(None), "")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        cookie_jar = SimpleCookieJar()
 | 
				
			||||||
 | 
					        cookie_jar.add("a=b; c=d; domain=abc")
 | 
				
			||||||
 | 
					        cookie_jar.add("e=f; domain=abc")
 | 
				
			||||||
 | 
					        self.assertEqual(cookie_jar.get("abc"), "a=b; c=d; e=f")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        cookie_jar = SimpleCookieJar()
 | 
				
			||||||
 | 
					        cookie_jar.add("a=b; c=d; domain=abc")
 | 
				
			||||||
 | 
					        cookie_jar.add("e=f; domain=.abc")
 | 
				
			||||||
 | 
					        self.assertEqual(cookie_jar.get("abc"), "a=b; c=d; e=f")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        cookie_jar = SimpleCookieJar()
 | 
				
			||||||
 | 
					        cookie_jar.add("a=b; c=d; domain=abc")
 | 
				
			||||||
 | 
					        cookie_jar.add("e=f; domain=xyz")
 | 
				
			||||||
 | 
					        self.assertEqual(cookie_jar.get("abc"), "a=b; c=d")
 | 
				
			||||||
 | 
					        self.assertEqual(cookie_jar.get("xyz"), "e=f")
 | 
				
			||||||
 | 
					        self.assertEqual(cookie_jar.get("something"), "")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_set(self):
 | 
				
			||||||
 | 
					        cookie_jar = SimpleCookieJar()
 | 
				
			||||||
 | 
					        cookie_jar.set("a=b")
 | 
				
			||||||
 | 
					        self.assertFalse(
 | 
				
			||||||
 | 
					            cookie_jar.jar, "Cookie with no domain should not be added to the jar"
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        cookie_jar = SimpleCookieJar()
 | 
				
			||||||
 | 
					        cookie_jar.set("a=b; domain=.abc")
 | 
				
			||||||
 | 
					        self.assertTrue(".abc" in cookie_jar.jar)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        cookie_jar = SimpleCookieJar()
 | 
				
			||||||
 | 
					        cookie_jar.set("a=b; domain=abc")
 | 
				
			||||||
 | 
					        self.assertTrue(".abc" in cookie_jar.jar)
 | 
				
			||||||
 | 
					        self.assertTrue("abc" not in cookie_jar.jar)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        cookie_jar = SimpleCookieJar()
 | 
				
			||||||
 | 
					        cookie_jar.set("a=b; c=d; domain=abc")
 | 
				
			||||||
 | 
					        self.assertEqual(cookie_jar.get("abc"), "a=b; c=d")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        cookie_jar = SimpleCookieJar()
 | 
				
			||||||
 | 
					        cookie_jar.set("a=b; c=d; domain=abc")
 | 
				
			||||||
 | 
					        cookie_jar.set("e=f; domain=abc")
 | 
				
			||||||
 | 
					        self.assertEqual(cookie_jar.get("abc"), "e=f")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        cookie_jar = SimpleCookieJar()
 | 
				
			||||||
 | 
					        cookie_jar.set("a=b; c=d; domain=abc")
 | 
				
			||||||
 | 
					        cookie_jar.set("e=f; domain=.abc")
 | 
				
			||||||
 | 
					        self.assertEqual(cookie_jar.get("abc"), "e=f")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        cookie_jar = SimpleCookieJar()
 | 
				
			||||||
 | 
					        cookie_jar.set("a=b; c=d; domain=abc")
 | 
				
			||||||
 | 
					        cookie_jar.set("e=f; domain=xyz")
 | 
				
			||||||
 | 
					        self.assertEqual(cookie_jar.get("abc"), "a=b; c=d")
 | 
				
			||||||
 | 
					        self.assertEqual(cookie_jar.get("xyz"), "e=f")
 | 
				
			||||||
 | 
					        self.assertEqual(cookie_jar.get("something"), "")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_get(self):
 | 
				
			||||||
 | 
					        cookie_jar = SimpleCookieJar()
 | 
				
			||||||
 | 
					        cookie_jar.set("a=b; c=d; domain=abc.com")
 | 
				
			||||||
 | 
					        self.assertEqual(cookie_jar.get("abc.com"), "a=b; c=d")
 | 
				
			||||||
 | 
					        self.assertEqual(cookie_jar.get("x.abc.com"), "a=b; c=d")
 | 
				
			||||||
 | 
					        self.assertEqual(cookie_jar.get("abc.com.es"), "")
 | 
				
			||||||
 | 
					        self.assertEqual(cookie_jar.get("xabc.com"), "")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        cookie_jar.set("a=b; c=d; domain=.abc.com")
 | 
				
			||||||
 | 
					        self.assertEqual(cookie_jar.get("abc.com"), "a=b; c=d")
 | 
				
			||||||
 | 
					        self.assertEqual(cookie_jar.get("x.abc.com"), "a=b; c=d")
 | 
				
			||||||
 | 
					        self.assertEqual(cookie_jar.get("abc.com.es"), "")
 | 
				
			||||||
 | 
					        self.assertEqual(cookie_jar.get("xabc.com"), "")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					if __name__ == "__main__":
 | 
				
			||||||
 | 
					    unittest.main()
 | 
				
			||||||
							
								
								
									
										370
									
								
								src/libs/websocket/tests/test_http.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										370
									
								
								src/libs/websocket/tests/test_http.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,370 @@
 | 
				
			|||||||
 | 
					# -*- coding: utf-8 -*-
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					import os
 | 
				
			||||||
 | 
					import os.path
 | 
				
			||||||
 | 
					import socket
 | 
				
			||||||
 | 
					import ssl
 | 
				
			||||||
 | 
					import unittest
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import websocket
 | 
				
			||||||
 | 
					from websocket._exceptions import WebSocketProxyException, WebSocketException
 | 
				
			||||||
 | 
					from websocket._http import (
 | 
				
			||||||
 | 
					    _get_addrinfo_list,
 | 
				
			||||||
 | 
					    _start_proxied_socket,
 | 
				
			||||||
 | 
					    _tunnel,
 | 
				
			||||||
 | 
					    connect,
 | 
				
			||||||
 | 
					    proxy_info,
 | 
				
			||||||
 | 
					    read_headers,
 | 
				
			||||||
 | 
					    HAVE_PYTHON_SOCKS,
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					"""
 | 
				
			||||||
 | 
					test_http.py
 | 
				
			||||||
 | 
					websocket - WebSocket client library for Python
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Copyright 2024 engn33r
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					You may obtain a copy of the License at
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					limitations under the License.
 | 
				
			||||||
 | 
					"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					try:
 | 
				
			||||||
 | 
					    from python_socks._errors import ProxyConnectionError, ProxyError, ProxyTimeoutError
 | 
				
			||||||
 | 
					except:
 | 
				
			||||||
 | 
					    from websocket._http import ProxyConnectionError, ProxyError, ProxyTimeoutError
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# Skip test to access the internet unless TEST_WITH_INTERNET == 1
 | 
				
			||||||
 | 
					TEST_WITH_INTERNET = os.environ.get("TEST_WITH_INTERNET", "0") == "1"
 | 
				
			||||||
 | 
					TEST_WITH_PROXY = os.environ.get("TEST_WITH_PROXY", "0") == "1"
 | 
				
			||||||
 | 
					# Skip tests relying on local websockets server unless LOCAL_WS_SERVER_PORT != -1
 | 
				
			||||||
 | 
					LOCAL_WS_SERVER_PORT = os.environ.get("LOCAL_WS_SERVER_PORT", "-1")
 | 
				
			||||||
 | 
					TEST_WITH_LOCAL_SERVER = LOCAL_WS_SERVER_PORT != "-1"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class SockMock:
 | 
				
			||||||
 | 
					    def __init__(self):
 | 
				
			||||||
 | 
					        self.data = []
 | 
				
			||||||
 | 
					        self.sent = []
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def add_packet(self, data):
 | 
				
			||||||
 | 
					        self.data.append(data)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def gettimeout(self):
 | 
				
			||||||
 | 
					        return None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def recv(self, bufsize):
 | 
				
			||||||
 | 
					        if self.data:
 | 
				
			||||||
 | 
					            e = self.data.pop(0)
 | 
				
			||||||
 | 
					            if isinstance(e, Exception):
 | 
				
			||||||
 | 
					                raise e
 | 
				
			||||||
 | 
					            if len(e) > bufsize:
 | 
				
			||||||
 | 
					                self.data.insert(0, e[bufsize:])
 | 
				
			||||||
 | 
					            return e[:bufsize]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def send(self, data):
 | 
				
			||||||
 | 
					        self.sent.append(data)
 | 
				
			||||||
 | 
					        return len(data)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def close(self):
 | 
				
			||||||
 | 
					        pass
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class HeaderSockMock(SockMock):
 | 
				
			||||||
 | 
					    def __init__(self, fname):
 | 
				
			||||||
 | 
					        SockMock.__init__(self)
 | 
				
			||||||
 | 
					        path = os.path.join(os.path.dirname(__file__), fname)
 | 
				
			||||||
 | 
					        with open(path, "rb") as f:
 | 
				
			||||||
 | 
					            self.add_packet(f.read())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class OptsList:
 | 
				
			||||||
 | 
					    def __init__(self):
 | 
				
			||||||
 | 
					        self.timeout = 1
 | 
				
			||||||
 | 
					        self.sockopt = []
 | 
				
			||||||
 | 
					        self.sslopt = {"cert_reqs": ssl.CERT_NONE}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class HttpTest(unittest.TestCase):
 | 
				
			||||||
 | 
					    def test_read_header(self):
 | 
				
			||||||
 | 
					        status, header, _ = read_headers(HeaderSockMock("data/header01.txt"))
 | 
				
			||||||
 | 
					        self.assertEqual(status, 101)
 | 
				
			||||||
 | 
					        self.assertEqual(header["connection"], "Upgrade")
 | 
				
			||||||
 | 
					        # header02.txt is intentionally malformed
 | 
				
			||||||
 | 
					        self.assertRaises(
 | 
				
			||||||
 | 
					            WebSocketException, read_headers, HeaderSockMock("data/header02.txt")
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_tunnel(self):
 | 
				
			||||||
 | 
					        self.assertRaises(
 | 
				
			||||||
 | 
					            WebSocketProxyException,
 | 
				
			||||||
 | 
					            _tunnel,
 | 
				
			||||||
 | 
					            HeaderSockMock("data/header01.txt"),
 | 
				
			||||||
 | 
					            "example.com",
 | 
				
			||||||
 | 
					            80,
 | 
				
			||||||
 | 
					            ("username", "password"),
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.assertRaises(
 | 
				
			||||||
 | 
					            WebSocketProxyException,
 | 
				
			||||||
 | 
					            _tunnel,
 | 
				
			||||||
 | 
					            HeaderSockMock("data/header02.txt"),
 | 
				
			||||||
 | 
					            "example.com",
 | 
				
			||||||
 | 
					            80,
 | 
				
			||||||
 | 
					            ("username", "password"),
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
 | 
				
			||||||
 | 
					    def test_connect(self):
 | 
				
			||||||
 | 
					        # Not currently testing an actual proxy connection, so just check whether proxy errors are raised. This requires internet for a DNS lookup
 | 
				
			||||||
 | 
					        if HAVE_PYTHON_SOCKS:
 | 
				
			||||||
 | 
					            # Need this check, otherwise case where python_socks is not installed triggers
 | 
				
			||||||
 | 
					            # websocket._exceptions.WebSocketException: Python Socks is needed for SOCKS proxying but is not available
 | 
				
			||||||
 | 
					            self.assertRaises(
 | 
				
			||||||
 | 
					                (ProxyTimeoutError, OSError),
 | 
				
			||||||
 | 
					                _start_proxied_socket,
 | 
				
			||||||
 | 
					                "wss://example.com",
 | 
				
			||||||
 | 
					                OptsList(),
 | 
				
			||||||
 | 
					                proxy_info(
 | 
				
			||||||
 | 
					                    http_proxy_host="example.com",
 | 
				
			||||||
 | 
					                    http_proxy_port="8080",
 | 
				
			||||||
 | 
					                    proxy_type="socks4",
 | 
				
			||||||
 | 
					                    http_proxy_timeout=1,
 | 
				
			||||||
 | 
					                ),
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					            self.assertRaises(
 | 
				
			||||||
 | 
					                (ProxyTimeoutError, OSError),
 | 
				
			||||||
 | 
					                _start_proxied_socket,
 | 
				
			||||||
 | 
					                "wss://example.com",
 | 
				
			||||||
 | 
					                OptsList(),
 | 
				
			||||||
 | 
					                proxy_info(
 | 
				
			||||||
 | 
					                    http_proxy_host="example.com",
 | 
				
			||||||
 | 
					                    http_proxy_port="8080",
 | 
				
			||||||
 | 
					                    proxy_type="socks4a",
 | 
				
			||||||
 | 
					                    http_proxy_timeout=1,
 | 
				
			||||||
 | 
					                ),
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					            self.assertRaises(
 | 
				
			||||||
 | 
					                (ProxyTimeoutError, OSError),
 | 
				
			||||||
 | 
					                _start_proxied_socket,
 | 
				
			||||||
 | 
					                "wss://example.com",
 | 
				
			||||||
 | 
					                OptsList(),
 | 
				
			||||||
 | 
					                proxy_info(
 | 
				
			||||||
 | 
					                    http_proxy_host="example.com",
 | 
				
			||||||
 | 
					                    http_proxy_port="8080",
 | 
				
			||||||
 | 
					                    proxy_type="socks5",
 | 
				
			||||||
 | 
					                    http_proxy_timeout=1,
 | 
				
			||||||
 | 
					                ),
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					            self.assertRaises(
 | 
				
			||||||
 | 
					                (ProxyTimeoutError, OSError),
 | 
				
			||||||
 | 
					                _start_proxied_socket,
 | 
				
			||||||
 | 
					                "wss://example.com",
 | 
				
			||||||
 | 
					                OptsList(),
 | 
				
			||||||
 | 
					                proxy_info(
 | 
				
			||||||
 | 
					                    http_proxy_host="example.com",
 | 
				
			||||||
 | 
					                    http_proxy_port="8080",
 | 
				
			||||||
 | 
					                    proxy_type="socks5h",
 | 
				
			||||||
 | 
					                    http_proxy_timeout=1,
 | 
				
			||||||
 | 
					                ),
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					            self.assertRaises(
 | 
				
			||||||
 | 
					                ProxyConnectionError,
 | 
				
			||||||
 | 
					                connect,
 | 
				
			||||||
 | 
					                "wss://example.com",
 | 
				
			||||||
 | 
					                OptsList(),
 | 
				
			||||||
 | 
					                proxy_info(
 | 
				
			||||||
 | 
					                    http_proxy_host="127.0.0.1",
 | 
				
			||||||
 | 
					                    http_proxy_port=9999,
 | 
				
			||||||
 | 
					                    proxy_type="socks4",
 | 
				
			||||||
 | 
					                    http_proxy_timeout=1,
 | 
				
			||||||
 | 
					                ),
 | 
				
			||||||
 | 
					                None,
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.assertRaises(
 | 
				
			||||||
 | 
					            TypeError,
 | 
				
			||||||
 | 
					            _get_addrinfo_list,
 | 
				
			||||||
 | 
					            None,
 | 
				
			||||||
 | 
					            80,
 | 
				
			||||||
 | 
					            True,
 | 
				
			||||||
 | 
					            proxy_info(
 | 
				
			||||||
 | 
					                http_proxy_host="127.0.0.1", http_proxy_port="9999", proxy_type="http"
 | 
				
			||||||
 | 
					            ),
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.assertRaises(
 | 
				
			||||||
 | 
					            TypeError,
 | 
				
			||||||
 | 
					            _get_addrinfo_list,
 | 
				
			||||||
 | 
					            None,
 | 
				
			||||||
 | 
					            80,
 | 
				
			||||||
 | 
					            True,
 | 
				
			||||||
 | 
					            proxy_info(
 | 
				
			||||||
 | 
					                http_proxy_host="127.0.0.1", http_proxy_port="9999", proxy_type="http"
 | 
				
			||||||
 | 
					            ),
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.assertRaises(
 | 
				
			||||||
 | 
					            socket.timeout,
 | 
				
			||||||
 | 
					            connect,
 | 
				
			||||||
 | 
					            "wss://google.com",
 | 
				
			||||||
 | 
					            OptsList(),
 | 
				
			||||||
 | 
					            proxy_info(
 | 
				
			||||||
 | 
					                http_proxy_host="8.8.8.8",
 | 
				
			||||||
 | 
					                http_proxy_port=9999,
 | 
				
			||||||
 | 
					                proxy_type="http",
 | 
				
			||||||
 | 
					                http_proxy_timeout=1,
 | 
				
			||||||
 | 
					            ),
 | 
				
			||||||
 | 
					            None,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.assertEqual(
 | 
				
			||||||
 | 
					            connect(
 | 
				
			||||||
 | 
					                "wss://google.com",
 | 
				
			||||||
 | 
					                OptsList(),
 | 
				
			||||||
 | 
					                proxy_info(
 | 
				
			||||||
 | 
					                    http_proxy_host="8.8.8.8", http_proxy_port=8080, proxy_type="http"
 | 
				
			||||||
 | 
					                ),
 | 
				
			||||||
 | 
					                True,
 | 
				
			||||||
 | 
					            ),
 | 
				
			||||||
 | 
					            (True, ("google.com", 443, "/")),
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        # The following test fails on Mac OS with a gaierror, not an OverflowError
 | 
				
			||||||
 | 
					        # self.assertRaises(OverflowError, connect, "wss://example.com", OptsList(), proxy_info(http_proxy_host="127.0.0.1", http_proxy_port=99999, proxy_type="socks4", timeout=2), False)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
 | 
				
			||||||
 | 
					    @unittest.skipUnless(
 | 
				
			||||||
 | 
					        TEST_WITH_PROXY, "This test requires a HTTP proxy to be running on port 8899"
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    @unittest.skipUnless(
 | 
				
			||||||
 | 
					        TEST_WITH_LOCAL_SERVER, "Tests using local websocket server are disabled"
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    def test_proxy_connect(self):
 | 
				
			||||||
 | 
					        ws = websocket.WebSocket()
 | 
				
			||||||
 | 
					        ws.connect(
 | 
				
			||||||
 | 
					            f"ws://127.0.0.1:{LOCAL_WS_SERVER_PORT}",
 | 
				
			||||||
 | 
					            http_proxy_host="127.0.0.1",
 | 
				
			||||||
 | 
					            http_proxy_port="8899",
 | 
				
			||||||
 | 
					            proxy_type="http",
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        ws.send("Hello, Server")
 | 
				
			||||||
 | 
					        server_response = ws.recv()
 | 
				
			||||||
 | 
					        self.assertEqual(server_response, "Hello, Server")
 | 
				
			||||||
 | 
					        # self.assertEqual(_start_proxied_socket("wss://api.bitfinex.com/ws/2", OptsList(), proxy_info(http_proxy_host="127.0.0.1", http_proxy_port="8899", proxy_type="http"))[1], ("api.bitfinex.com", 443, '/ws/2'))
 | 
				
			||||||
 | 
					        self.assertEqual(
 | 
				
			||||||
 | 
					            _get_addrinfo_list(
 | 
				
			||||||
 | 
					                "api.bitfinex.com",
 | 
				
			||||||
 | 
					                443,
 | 
				
			||||||
 | 
					                True,
 | 
				
			||||||
 | 
					                proxy_info(
 | 
				
			||||||
 | 
					                    http_proxy_host="127.0.0.1",
 | 
				
			||||||
 | 
					                    http_proxy_port="8899",
 | 
				
			||||||
 | 
					                    proxy_type="http",
 | 
				
			||||||
 | 
					                ),
 | 
				
			||||||
 | 
					            ),
 | 
				
			||||||
 | 
					            (
 | 
				
			||||||
 | 
					                socket.getaddrinfo(
 | 
				
			||||||
 | 
					                    "127.0.0.1", 8899, 0, socket.SOCK_STREAM, socket.SOL_TCP
 | 
				
			||||||
 | 
					                ),
 | 
				
			||||||
 | 
					                True,
 | 
				
			||||||
 | 
					                None,
 | 
				
			||||||
 | 
					            ),
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.assertEqual(
 | 
				
			||||||
 | 
					            connect(
 | 
				
			||||||
 | 
					                "wss://api.bitfinex.com/ws/2",
 | 
				
			||||||
 | 
					                OptsList(),
 | 
				
			||||||
 | 
					                proxy_info(
 | 
				
			||||||
 | 
					                    http_proxy_host="127.0.0.1", http_proxy_port=8899, proxy_type="http"
 | 
				
			||||||
 | 
					                ),
 | 
				
			||||||
 | 
					                None,
 | 
				
			||||||
 | 
					            )[1],
 | 
				
			||||||
 | 
					            ("api.bitfinex.com", 443, "/ws/2"),
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        # TODO: Test SOCKS4 and SOCK5 proxies with unit tests
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
 | 
				
			||||||
 | 
					    def test_sslopt(self):
 | 
				
			||||||
 | 
					        ssloptions = {
 | 
				
			||||||
 | 
					            "check_hostname": False,
 | 
				
			||||||
 | 
					            "server_hostname": "ServerName",
 | 
				
			||||||
 | 
					            "ssl_version": ssl.PROTOCOL_TLS_CLIENT,
 | 
				
			||||||
 | 
					            "ciphers": "TLS_AES_256_GCM_SHA384:TLS_CHACHA20_POLY1305_SHA256:\
 | 
				
			||||||
 | 
					                        TLS_AES_128_GCM_SHA256:ECDHE-ECDSA-AES256-GCM-SHA384:\
 | 
				
			||||||
 | 
					                        ECDHE-RSA-AES256-GCM-SHA384:DHE-RSA-AES256-GCM-SHA384:\
 | 
				
			||||||
 | 
					                        ECDHE-ECDSA-CHACHA20-POLY1305:ECDHE-RSA-CHACHA20-POLY1305:\
 | 
				
			||||||
 | 
					                        DHE-RSA-CHACHA20-POLY1305:ECDHE-ECDSA-AES128-GCM-SHA256:\
 | 
				
			||||||
 | 
					                        ECDHE-RSA-AES128-GCM-SHA256:DHE-RSA-AES128-GCM-SHA256:\
 | 
				
			||||||
 | 
					                        ECDHE-ECDSA-AES256-SHA384:ECDHE-RSA-AES256-SHA384:\
 | 
				
			||||||
 | 
					                        DHE-RSA-AES256-SHA256:ECDHE-ECDSA-AES128-SHA256:\
 | 
				
			||||||
 | 
					                        ECDHE-RSA-AES128-SHA256:DHE-RSA-AES128-SHA256:\
 | 
				
			||||||
 | 
					                        ECDHE-ECDSA-AES256-SHA:ECDHE-RSA-AES256-SHA",
 | 
				
			||||||
 | 
					            "ecdh_curve": "prime256v1",
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					        ws_ssl1 = websocket.WebSocket(sslopt=ssloptions)
 | 
				
			||||||
 | 
					        ws_ssl1.connect("wss://api.bitfinex.com/ws/2")
 | 
				
			||||||
 | 
					        ws_ssl1.send("Hello")
 | 
				
			||||||
 | 
					        ws_ssl1.close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        ws_ssl2 = websocket.WebSocket(sslopt={"check_hostname": True})
 | 
				
			||||||
 | 
					        ws_ssl2.connect("wss://api.bitfinex.com/ws/2")
 | 
				
			||||||
 | 
					        ws_ssl2.close
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_proxy_info(self):
 | 
				
			||||||
 | 
					        self.assertEqual(
 | 
				
			||||||
 | 
					            proxy_info(
 | 
				
			||||||
 | 
					                http_proxy_host="127.0.0.1", http_proxy_port="8080", proxy_type="http"
 | 
				
			||||||
 | 
					            ).proxy_protocol,
 | 
				
			||||||
 | 
					            "http",
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.assertRaises(
 | 
				
			||||||
 | 
					            ProxyError,
 | 
				
			||||||
 | 
					            proxy_info,
 | 
				
			||||||
 | 
					            http_proxy_host="127.0.0.1",
 | 
				
			||||||
 | 
					            http_proxy_port="8080",
 | 
				
			||||||
 | 
					            proxy_type="badval",
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.assertEqual(
 | 
				
			||||||
 | 
					            proxy_info(
 | 
				
			||||||
 | 
					                http_proxy_host="example.com", http_proxy_port="8080", proxy_type="http"
 | 
				
			||||||
 | 
					            ).proxy_host,
 | 
				
			||||||
 | 
					            "example.com",
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.assertEqual(
 | 
				
			||||||
 | 
					            proxy_info(
 | 
				
			||||||
 | 
					                http_proxy_host="127.0.0.1", http_proxy_port="8080", proxy_type="http"
 | 
				
			||||||
 | 
					            ).proxy_port,
 | 
				
			||||||
 | 
					            "8080",
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.assertEqual(
 | 
				
			||||||
 | 
					            proxy_info(
 | 
				
			||||||
 | 
					                http_proxy_host="127.0.0.1", http_proxy_port="8080", proxy_type="http"
 | 
				
			||||||
 | 
					            ).auth,
 | 
				
			||||||
 | 
					            None,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.assertEqual(
 | 
				
			||||||
 | 
					            proxy_info(
 | 
				
			||||||
 | 
					                http_proxy_host="127.0.0.1",
 | 
				
			||||||
 | 
					                http_proxy_port="8080",
 | 
				
			||||||
 | 
					                proxy_type="http",
 | 
				
			||||||
 | 
					                http_proxy_auth=("my_username123", "my_pass321"),
 | 
				
			||||||
 | 
					            ).auth[0],
 | 
				
			||||||
 | 
					            "my_username123",
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.assertEqual(
 | 
				
			||||||
 | 
					            proxy_info(
 | 
				
			||||||
 | 
					                http_proxy_host="127.0.0.1",
 | 
				
			||||||
 | 
					                http_proxy_port="8080",
 | 
				
			||||||
 | 
					                proxy_type="http",
 | 
				
			||||||
 | 
					                http_proxy_auth=("my_username123", "my_pass321"),
 | 
				
			||||||
 | 
					            ).auth[1],
 | 
				
			||||||
 | 
					            "my_pass321",
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					if __name__ == "__main__":
 | 
				
			||||||
 | 
					    unittest.main()
 | 
				
			||||||
							
								
								
									
										464
									
								
								src/libs/websocket/tests/test_url.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										464
									
								
								src/libs/websocket/tests/test_url.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,464 @@
 | 
				
			|||||||
 | 
					# -*- coding: utf-8 -*-
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					import os
 | 
				
			||||||
 | 
					import unittest
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from websocket._url import (
 | 
				
			||||||
 | 
					    _is_address_in_network,
 | 
				
			||||||
 | 
					    _is_no_proxy_host,
 | 
				
			||||||
 | 
					    get_proxy_info,
 | 
				
			||||||
 | 
					    parse_url,
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					from websocket._exceptions import WebSocketProxyException
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					"""
 | 
				
			||||||
 | 
					test_url.py
 | 
				
			||||||
 | 
					websocket - WebSocket client library for Python
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Copyright 2024 engn33r
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					You may obtain a copy of the License at
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					limitations under the License.
 | 
				
			||||||
 | 
					"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class UrlTest(unittest.TestCase):
 | 
				
			||||||
 | 
					    def test_address_in_network(self):
 | 
				
			||||||
 | 
					        self.assertTrue(_is_address_in_network("127.0.0.1", "127.0.0.0/8"))
 | 
				
			||||||
 | 
					        self.assertTrue(_is_address_in_network("127.1.0.1", "127.0.0.0/8"))
 | 
				
			||||||
 | 
					        self.assertFalse(_is_address_in_network("127.1.0.1", "127.0.0.0/24"))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_parse_url(self):
 | 
				
			||||||
 | 
					        p = parse_url("ws://www.example.com/r")
 | 
				
			||||||
 | 
					        self.assertEqual(p[0], "www.example.com")
 | 
				
			||||||
 | 
					        self.assertEqual(p[1], 80)
 | 
				
			||||||
 | 
					        self.assertEqual(p[2], "/r")
 | 
				
			||||||
 | 
					        self.assertEqual(p[3], False)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        p = parse_url("ws://www.example.com/r/")
 | 
				
			||||||
 | 
					        self.assertEqual(p[0], "www.example.com")
 | 
				
			||||||
 | 
					        self.assertEqual(p[1], 80)
 | 
				
			||||||
 | 
					        self.assertEqual(p[2], "/r/")
 | 
				
			||||||
 | 
					        self.assertEqual(p[3], False)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        p = parse_url("ws://www.example.com/")
 | 
				
			||||||
 | 
					        self.assertEqual(p[0], "www.example.com")
 | 
				
			||||||
 | 
					        self.assertEqual(p[1], 80)
 | 
				
			||||||
 | 
					        self.assertEqual(p[2], "/")
 | 
				
			||||||
 | 
					        self.assertEqual(p[3], False)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        p = parse_url("ws://www.example.com")
 | 
				
			||||||
 | 
					        self.assertEqual(p[0], "www.example.com")
 | 
				
			||||||
 | 
					        self.assertEqual(p[1], 80)
 | 
				
			||||||
 | 
					        self.assertEqual(p[2], "/")
 | 
				
			||||||
 | 
					        self.assertEqual(p[3], False)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        p = parse_url("ws://www.example.com:8080/r")
 | 
				
			||||||
 | 
					        self.assertEqual(p[0], "www.example.com")
 | 
				
			||||||
 | 
					        self.assertEqual(p[1], 8080)
 | 
				
			||||||
 | 
					        self.assertEqual(p[2], "/r")
 | 
				
			||||||
 | 
					        self.assertEqual(p[3], False)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        p = parse_url("ws://www.example.com:8080/")
 | 
				
			||||||
 | 
					        self.assertEqual(p[0], "www.example.com")
 | 
				
			||||||
 | 
					        self.assertEqual(p[1], 8080)
 | 
				
			||||||
 | 
					        self.assertEqual(p[2], "/")
 | 
				
			||||||
 | 
					        self.assertEqual(p[3], False)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        p = parse_url("ws://www.example.com:8080")
 | 
				
			||||||
 | 
					        self.assertEqual(p[0], "www.example.com")
 | 
				
			||||||
 | 
					        self.assertEqual(p[1], 8080)
 | 
				
			||||||
 | 
					        self.assertEqual(p[2], "/")
 | 
				
			||||||
 | 
					        self.assertEqual(p[3], False)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        p = parse_url("wss://www.example.com:8080/r")
 | 
				
			||||||
 | 
					        self.assertEqual(p[0], "www.example.com")
 | 
				
			||||||
 | 
					        self.assertEqual(p[1], 8080)
 | 
				
			||||||
 | 
					        self.assertEqual(p[2], "/r")
 | 
				
			||||||
 | 
					        self.assertEqual(p[3], True)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        p = parse_url("wss://www.example.com:8080/r?key=value")
 | 
				
			||||||
 | 
					        self.assertEqual(p[0], "www.example.com")
 | 
				
			||||||
 | 
					        self.assertEqual(p[1], 8080)
 | 
				
			||||||
 | 
					        self.assertEqual(p[2], "/r?key=value")
 | 
				
			||||||
 | 
					        self.assertEqual(p[3], True)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.assertRaises(ValueError, parse_url, "http://www.example.com/r")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        p = parse_url("ws://[2a03:4000:123:83::3]/r")
 | 
				
			||||||
 | 
					        self.assertEqual(p[0], "2a03:4000:123:83::3")
 | 
				
			||||||
 | 
					        self.assertEqual(p[1], 80)
 | 
				
			||||||
 | 
					        self.assertEqual(p[2], "/r")
 | 
				
			||||||
 | 
					        self.assertEqual(p[3], False)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        p = parse_url("ws://[2a03:4000:123:83::3]:8080/r")
 | 
				
			||||||
 | 
					        self.assertEqual(p[0], "2a03:4000:123:83::3")
 | 
				
			||||||
 | 
					        self.assertEqual(p[1], 8080)
 | 
				
			||||||
 | 
					        self.assertEqual(p[2], "/r")
 | 
				
			||||||
 | 
					        self.assertEqual(p[3], False)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        p = parse_url("wss://[2a03:4000:123:83::3]/r")
 | 
				
			||||||
 | 
					        self.assertEqual(p[0], "2a03:4000:123:83::3")
 | 
				
			||||||
 | 
					        self.assertEqual(p[1], 443)
 | 
				
			||||||
 | 
					        self.assertEqual(p[2], "/r")
 | 
				
			||||||
 | 
					        self.assertEqual(p[3], True)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        p = parse_url("wss://[2a03:4000:123:83::3]:8080/r")
 | 
				
			||||||
 | 
					        self.assertEqual(p[0], "2a03:4000:123:83::3")
 | 
				
			||||||
 | 
					        self.assertEqual(p[1], 8080)
 | 
				
			||||||
 | 
					        self.assertEqual(p[2], "/r")
 | 
				
			||||||
 | 
					        self.assertEqual(p[3], True)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class IsNoProxyHostTest(unittest.TestCase):
 | 
				
			||||||
 | 
					    def setUp(self):
 | 
				
			||||||
 | 
					        self.no_proxy = os.environ.get("no_proxy", None)
 | 
				
			||||||
 | 
					        if "no_proxy" in os.environ:
 | 
				
			||||||
 | 
					            del os.environ["no_proxy"]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def tearDown(self):
 | 
				
			||||||
 | 
					        if self.no_proxy:
 | 
				
			||||||
 | 
					            os.environ["no_proxy"] = self.no_proxy
 | 
				
			||||||
 | 
					        elif "no_proxy" in os.environ:
 | 
				
			||||||
 | 
					            del os.environ["no_proxy"]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_match_all(self):
 | 
				
			||||||
 | 
					        self.assertTrue(_is_no_proxy_host("any.websocket.org", ["*"]))
 | 
				
			||||||
 | 
					        self.assertTrue(_is_no_proxy_host("192.168.0.1", ["*"]))
 | 
				
			||||||
 | 
					        self.assertFalse(_is_no_proxy_host("192.168.0.1", ["192.168.1.1"]))
 | 
				
			||||||
 | 
					        self.assertFalse(
 | 
				
			||||||
 | 
					            _is_no_proxy_host("any.websocket.org", ["other.websocket.org"])
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.assertTrue(
 | 
				
			||||||
 | 
					            _is_no_proxy_host("any.websocket.org", ["other.websocket.org", "*"])
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        os.environ["no_proxy"] = "*"
 | 
				
			||||||
 | 
					        self.assertTrue(_is_no_proxy_host("any.websocket.org", None))
 | 
				
			||||||
 | 
					        self.assertTrue(_is_no_proxy_host("192.168.0.1", None))
 | 
				
			||||||
 | 
					        os.environ["no_proxy"] = "other.websocket.org, *"
 | 
				
			||||||
 | 
					        self.assertTrue(_is_no_proxy_host("any.websocket.org", None))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_ip_address(self):
 | 
				
			||||||
 | 
					        self.assertTrue(_is_no_proxy_host("127.0.0.1", ["127.0.0.1"]))
 | 
				
			||||||
 | 
					        self.assertFalse(_is_no_proxy_host("127.0.0.2", ["127.0.0.1"]))
 | 
				
			||||||
 | 
					        self.assertTrue(
 | 
				
			||||||
 | 
					            _is_no_proxy_host("127.0.0.1", ["other.websocket.org", "127.0.0.1"])
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.assertFalse(
 | 
				
			||||||
 | 
					            _is_no_proxy_host("127.0.0.2", ["other.websocket.org", "127.0.0.1"])
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        os.environ["no_proxy"] = "127.0.0.1"
 | 
				
			||||||
 | 
					        self.assertTrue(_is_no_proxy_host("127.0.0.1", None))
 | 
				
			||||||
 | 
					        self.assertFalse(_is_no_proxy_host("127.0.0.2", None))
 | 
				
			||||||
 | 
					        os.environ["no_proxy"] = "other.websocket.org, 127.0.0.1"
 | 
				
			||||||
 | 
					        self.assertTrue(_is_no_proxy_host("127.0.0.1", None))
 | 
				
			||||||
 | 
					        self.assertFalse(_is_no_proxy_host("127.0.0.2", None))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_ip_address_in_range(self):
 | 
				
			||||||
 | 
					        self.assertTrue(_is_no_proxy_host("127.0.0.1", ["127.0.0.0/8"]))
 | 
				
			||||||
 | 
					        self.assertTrue(_is_no_proxy_host("127.0.0.2", ["127.0.0.0/8"]))
 | 
				
			||||||
 | 
					        self.assertFalse(_is_no_proxy_host("127.1.0.1", ["127.0.0.0/24"]))
 | 
				
			||||||
 | 
					        os.environ["no_proxy"] = "127.0.0.0/8"
 | 
				
			||||||
 | 
					        self.assertTrue(_is_no_proxy_host("127.0.0.1", None))
 | 
				
			||||||
 | 
					        self.assertTrue(_is_no_proxy_host("127.0.0.2", None))
 | 
				
			||||||
 | 
					        os.environ["no_proxy"] = "127.0.0.0/24"
 | 
				
			||||||
 | 
					        self.assertFalse(_is_no_proxy_host("127.1.0.1", None))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_hostname_match(self):
 | 
				
			||||||
 | 
					        self.assertTrue(_is_no_proxy_host("my.websocket.org", ["my.websocket.org"]))
 | 
				
			||||||
 | 
					        self.assertTrue(
 | 
				
			||||||
 | 
					            _is_no_proxy_host(
 | 
				
			||||||
 | 
					                "my.websocket.org", ["other.websocket.org", "my.websocket.org"]
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.assertFalse(_is_no_proxy_host("my.websocket.org", ["other.websocket.org"]))
 | 
				
			||||||
 | 
					        os.environ["no_proxy"] = "my.websocket.org"
 | 
				
			||||||
 | 
					        self.assertTrue(_is_no_proxy_host("my.websocket.org", None))
 | 
				
			||||||
 | 
					        self.assertFalse(_is_no_proxy_host("other.websocket.org", None))
 | 
				
			||||||
 | 
					        os.environ["no_proxy"] = "other.websocket.org, my.websocket.org"
 | 
				
			||||||
 | 
					        self.assertTrue(_is_no_proxy_host("my.websocket.org", None))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_hostname_match_domain(self):
 | 
				
			||||||
 | 
					        self.assertTrue(_is_no_proxy_host("any.websocket.org", [".websocket.org"]))
 | 
				
			||||||
 | 
					        self.assertTrue(_is_no_proxy_host("my.other.websocket.org", [".websocket.org"]))
 | 
				
			||||||
 | 
					        self.assertTrue(
 | 
				
			||||||
 | 
					            _is_no_proxy_host(
 | 
				
			||||||
 | 
					                "any.websocket.org", ["my.websocket.org", ".websocket.org"]
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.assertFalse(_is_no_proxy_host("any.websocket.com", [".websocket.org"]))
 | 
				
			||||||
 | 
					        os.environ["no_proxy"] = ".websocket.org"
 | 
				
			||||||
 | 
					        self.assertTrue(_is_no_proxy_host("any.websocket.org", None))
 | 
				
			||||||
 | 
					        self.assertTrue(_is_no_proxy_host("my.other.websocket.org", None))
 | 
				
			||||||
 | 
					        self.assertFalse(_is_no_proxy_host("any.websocket.com", None))
 | 
				
			||||||
 | 
					        os.environ["no_proxy"] = "my.websocket.org, .websocket.org"
 | 
				
			||||||
 | 
					        self.assertTrue(_is_no_proxy_host("any.websocket.org", None))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class ProxyInfoTest(unittest.TestCase):
 | 
				
			||||||
 | 
					    def setUp(self):
 | 
				
			||||||
 | 
					        self.http_proxy = os.environ.get("http_proxy", None)
 | 
				
			||||||
 | 
					        self.https_proxy = os.environ.get("https_proxy", None)
 | 
				
			||||||
 | 
					        self.no_proxy = os.environ.get("no_proxy", None)
 | 
				
			||||||
 | 
					        if "http_proxy" in os.environ:
 | 
				
			||||||
 | 
					            del os.environ["http_proxy"]
 | 
				
			||||||
 | 
					        if "https_proxy" in os.environ:
 | 
				
			||||||
 | 
					            del os.environ["https_proxy"]
 | 
				
			||||||
 | 
					        if "no_proxy" in os.environ:
 | 
				
			||||||
 | 
					            del os.environ["no_proxy"]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def tearDown(self):
 | 
				
			||||||
 | 
					        if self.http_proxy:
 | 
				
			||||||
 | 
					            os.environ["http_proxy"] = self.http_proxy
 | 
				
			||||||
 | 
					        elif "http_proxy" in os.environ:
 | 
				
			||||||
 | 
					            del os.environ["http_proxy"]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if self.https_proxy:
 | 
				
			||||||
 | 
					            os.environ["https_proxy"] = self.https_proxy
 | 
				
			||||||
 | 
					        elif "https_proxy" in os.environ:
 | 
				
			||||||
 | 
					            del os.environ["https_proxy"]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if self.no_proxy:
 | 
				
			||||||
 | 
					            os.environ["no_proxy"] = self.no_proxy
 | 
				
			||||||
 | 
					        elif "no_proxy" in os.environ:
 | 
				
			||||||
 | 
					            del os.environ["no_proxy"]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_proxy_from_args(self):
 | 
				
			||||||
 | 
					        self.assertRaises(
 | 
				
			||||||
 | 
					            WebSocketProxyException,
 | 
				
			||||||
 | 
					            get_proxy_info,
 | 
				
			||||||
 | 
					            "echo.websocket.events",
 | 
				
			||||||
 | 
					            False,
 | 
				
			||||||
 | 
					            proxy_host="localhost",
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.assertEqual(
 | 
				
			||||||
 | 
					            get_proxy_info(
 | 
				
			||||||
 | 
					                "echo.websocket.events", False, proxy_host="localhost", proxy_port=3128
 | 
				
			||||||
 | 
					            ),
 | 
				
			||||||
 | 
					            ("localhost", 3128, None),
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.assertEqual(
 | 
				
			||||||
 | 
					            get_proxy_info(
 | 
				
			||||||
 | 
					                "echo.websocket.events", True, proxy_host="localhost", proxy_port=3128
 | 
				
			||||||
 | 
					            ),
 | 
				
			||||||
 | 
					            ("localhost", 3128, None),
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.assertEqual(
 | 
				
			||||||
 | 
					            get_proxy_info(
 | 
				
			||||||
 | 
					                "echo.websocket.events",
 | 
				
			||||||
 | 
					                False,
 | 
				
			||||||
 | 
					                proxy_host="localhost",
 | 
				
			||||||
 | 
					                proxy_port=9001,
 | 
				
			||||||
 | 
					                proxy_auth=("a", "b"),
 | 
				
			||||||
 | 
					            ),
 | 
				
			||||||
 | 
					            ("localhost", 9001, ("a", "b")),
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.assertEqual(
 | 
				
			||||||
 | 
					            get_proxy_info(
 | 
				
			||||||
 | 
					                "echo.websocket.events",
 | 
				
			||||||
 | 
					                False,
 | 
				
			||||||
 | 
					                proxy_host="localhost",
 | 
				
			||||||
 | 
					                proxy_port=3128,
 | 
				
			||||||
 | 
					                proxy_auth=("a", "b"),
 | 
				
			||||||
 | 
					            ),
 | 
				
			||||||
 | 
					            ("localhost", 3128, ("a", "b")),
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.assertEqual(
 | 
				
			||||||
 | 
					            get_proxy_info(
 | 
				
			||||||
 | 
					                "echo.websocket.events",
 | 
				
			||||||
 | 
					                True,
 | 
				
			||||||
 | 
					                proxy_host="localhost",
 | 
				
			||||||
 | 
					                proxy_port=8765,
 | 
				
			||||||
 | 
					                proxy_auth=("a", "b"),
 | 
				
			||||||
 | 
					            ),
 | 
				
			||||||
 | 
					            ("localhost", 8765, ("a", "b")),
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.assertEqual(
 | 
				
			||||||
 | 
					            get_proxy_info(
 | 
				
			||||||
 | 
					                "echo.websocket.events",
 | 
				
			||||||
 | 
					                True,
 | 
				
			||||||
 | 
					                proxy_host="localhost",
 | 
				
			||||||
 | 
					                proxy_port=3128,
 | 
				
			||||||
 | 
					                proxy_auth=("a", "b"),
 | 
				
			||||||
 | 
					            ),
 | 
				
			||||||
 | 
					            ("localhost", 3128, ("a", "b")),
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.assertEqual(
 | 
				
			||||||
 | 
					            get_proxy_info(
 | 
				
			||||||
 | 
					                "echo.websocket.events",
 | 
				
			||||||
 | 
					                True,
 | 
				
			||||||
 | 
					                proxy_host="localhost",
 | 
				
			||||||
 | 
					                proxy_port=3128,
 | 
				
			||||||
 | 
					                no_proxy=["example.com"],
 | 
				
			||||||
 | 
					                proxy_auth=("a", "b"),
 | 
				
			||||||
 | 
					            ),
 | 
				
			||||||
 | 
					            ("localhost", 3128, ("a", "b")),
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.assertEqual(
 | 
				
			||||||
 | 
					            get_proxy_info(
 | 
				
			||||||
 | 
					                "echo.websocket.events",
 | 
				
			||||||
 | 
					                True,
 | 
				
			||||||
 | 
					                proxy_host="localhost",
 | 
				
			||||||
 | 
					                proxy_port=3128,
 | 
				
			||||||
 | 
					                no_proxy=["echo.websocket.events"],
 | 
				
			||||||
 | 
					                proxy_auth=("a", "b"),
 | 
				
			||||||
 | 
					            ),
 | 
				
			||||||
 | 
					            (None, 0, None),
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.assertEqual(
 | 
				
			||||||
 | 
					            get_proxy_info(
 | 
				
			||||||
 | 
					                "echo.websocket.events",
 | 
				
			||||||
 | 
					                True,
 | 
				
			||||||
 | 
					                proxy_host="localhost",
 | 
				
			||||||
 | 
					                proxy_port=3128,
 | 
				
			||||||
 | 
					                no_proxy=[".websocket.events"],
 | 
				
			||||||
 | 
					            ),
 | 
				
			||||||
 | 
					            (None, 0, None),
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_proxy_from_env(self):
 | 
				
			||||||
 | 
					        os.environ["http_proxy"] = "http://localhost/"
 | 
				
			||||||
 | 
					        self.assertEqual(
 | 
				
			||||||
 | 
					            get_proxy_info("echo.websocket.events", False), ("localhost", None, None)
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        os.environ["http_proxy"] = "http://localhost:3128/"
 | 
				
			||||||
 | 
					        self.assertEqual(
 | 
				
			||||||
 | 
					            get_proxy_info("echo.websocket.events", False), ("localhost", 3128, None)
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        os.environ["http_proxy"] = "http://localhost/"
 | 
				
			||||||
 | 
					        os.environ["https_proxy"] = "http://localhost2/"
 | 
				
			||||||
 | 
					        self.assertEqual(
 | 
				
			||||||
 | 
					            get_proxy_info("echo.websocket.events", False), ("localhost", None, None)
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        os.environ["http_proxy"] = "http://localhost:3128/"
 | 
				
			||||||
 | 
					        os.environ["https_proxy"] = "http://localhost2:3128/"
 | 
				
			||||||
 | 
					        self.assertEqual(
 | 
				
			||||||
 | 
					            get_proxy_info("echo.websocket.events", False), ("localhost", 3128, None)
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        os.environ["http_proxy"] = "http://localhost/"
 | 
				
			||||||
 | 
					        os.environ["https_proxy"] = "http://localhost2/"
 | 
				
			||||||
 | 
					        self.assertEqual(
 | 
				
			||||||
 | 
					            get_proxy_info("echo.websocket.events", True), ("localhost2", None, None)
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        os.environ["http_proxy"] = "http://localhost:3128/"
 | 
				
			||||||
 | 
					        os.environ["https_proxy"] = "http://localhost2:3128/"
 | 
				
			||||||
 | 
					        self.assertEqual(
 | 
				
			||||||
 | 
					            get_proxy_info("echo.websocket.events", True), ("localhost2", 3128, None)
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        os.environ["http_proxy"] = ""
 | 
				
			||||||
 | 
					        os.environ["https_proxy"] = "http://localhost2/"
 | 
				
			||||||
 | 
					        self.assertEqual(
 | 
				
			||||||
 | 
					            get_proxy_info("echo.websocket.events", True), ("localhost2", None, None)
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.assertEqual(
 | 
				
			||||||
 | 
					            get_proxy_info("echo.websocket.events", False), (None, 0, None)
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        os.environ["http_proxy"] = ""
 | 
				
			||||||
 | 
					        os.environ["https_proxy"] = "http://localhost2:3128/"
 | 
				
			||||||
 | 
					        self.assertEqual(
 | 
				
			||||||
 | 
					            get_proxy_info("echo.websocket.events", True), ("localhost2", 3128, None)
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.assertEqual(
 | 
				
			||||||
 | 
					            get_proxy_info("echo.websocket.events", False), (None, 0, None)
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        os.environ["http_proxy"] = "http://localhost/"
 | 
				
			||||||
 | 
					        os.environ["https_proxy"] = ""
 | 
				
			||||||
 | 
					        self.assertEqual(get_proxy_info("echo.websocket.events", True), (None, 0, None))
 | 
				
			||||||
 | 
					        self.assertEqual(
 | 
				
			||||||
 | 
					            get_proxy_info("echo.websocket.events", False), ("localhost", None, None)
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        os.environ["http_proxy"] = "http://localhost:3128/"
 | 
				
			||||||
 | 
					        os.environ["https_proxy"] = ""
 | 
				
			||||||
 | 
					        self.assertEqual(get_proxy_info("echo.websocket.events", True), (None, 0, None))
 | 
				
			||||||
 | 
					        self.assertEqual(
 | 
				
			||||||
 | 
					            get_proxy_info("echo.websocket.events", False), ("localhost", 3128, None)
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        os.environ["http_proxy"] = "http://a:b@localhost/"
 | 
				
			||||||
 | 
					        self.assertEqual(
 | 
				
			||||||
 | 
					            get_proxy_info("echo.websocket.events", False),
 | 
				
			||||||
 | 
					            ("localhost", None, ("a", "b")),
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        os.environ["http_proxy"] = "http://a:b@localhost:3128/"
 | 
				
			||||||
 | 
					        self.assertEqual(
 | 
				
			||||||
 | 
					            get_proxy_info("echo.websocket.events", False),
 | 
				
			||||||
 | 
					            ("localhost", 3128, ("a", "b")),
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        os.environ["http_proxy"] = "http://a:b@localhost/"
 | 
				
			||||||
 | 
					        os.environ["https_proxy"] = "http://a:b@localhost2/"
 | 
				
			||||||
 | 
					        self.assertEqual(
 | 
				
			||||||
 | 
					            get_proxy_info("echo.websocket.events", False),
 | 
				
			||||||
 | 
					            ("localhost", None, ("a", "b")),
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        os.environ["http_proxy"] = "http://a:b@localhost:3128/"
 | 
				
			||||||
 | 
					        os.environ["https_proxy"] = "http://a:b@localhost2:3128/"
 | 
				
			||||||
 | 
					        self.assertEqual(
 | 
				
			||||||
 | 
					            get_proxy_info("echo.websocket.events", False),
 | 
				
			||||||
 | 
					            ("localhost", 3128, ("a", "b")),
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        os.environ["http_proxy"] = "http://a:b@localhost/"
 | 
				
			||||||
 | 
					        os.environ["https_proxy"] = "http://a:b@localhost2/"
 | 
				
			||||||
 | 
					        self.assertEqual(
 | 
				
			||||||
 | 
					            get_proxy_info("echo.websocket.events", True),
 | 
				
			||||||
 | 
					            ("localhost2", None, ("a", "b")),
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        os.environ["http_proxy"] = "http://a:b@localhost:3128/"
 | 
				
			||||||
 | 
					        os.environ["https_proxy"] = "http://a:b@localhost2:3128/"
 | 
				
			||||||
 | 
					        self.assertEqual(
 | 
				
			||||||
 | 
					            get_proxy_info("echo.websocket.events", True),
 | 
				
			||||||
 | 
					            ("localhost2", 3128, ("a", "b")),
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        os.environ["http_proxy"] = (
 | 
				
			||||||
 | 
					            "http://john%40example.com:P%40SSWORD@localhost:3128/"
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        os.environ["https_proxy"] = (
 | 
				
			||||||
 | 
					            "http://john%40example.com:P%40SSWORD@localhost2:3128/"
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.assertEqual(
 | 
				
			||||||
 | 
					            get_proxy_info("echo.websocket.events", True),
 | 
				
			||||||
 | 
					            ("localhost2", 3128, ("john@example.com", "P@SSWORD")),
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        os.environ["http_proxy"] = "http://a:b@localhost/"
 | 
				
			||||||
 | 
					        os.environ["https_proxy"] = "http://a:b@localhost2/"
 | 
				
			||||||
 | 
					        os.environ["no_proxy"] = "example1.com,example2.com"
 | 
				
			||||||
 | 
					        self.assertEqual(
 | 
				
			||||||
 | 
					            get_proxy_info("example.1.com", True), ("localhost2", None, ("a", "b"))
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        os.environ["http_proxy"] = "http://a:b@localhost:3128/"
 | 
				
			||||||
 | 
					        os.environ["https_proxy"] = "http://a:b@localhost2:3128/"
 | 
				
			||||||
 | 
					        os.environ["no_proxy"] = "example1.com,example2.com, echo.websocket.events"
 | 
				
			||||||
 | 
					        self.assertEqual(get_proxy_info("echo.websocket.events", True), (None, 0, None))
 | 
				
			||||||
 | 
					        os.environ["http_proxy"] = "http://a:b@localhost:3128/"
 | 
				
			||||||
 | 
					        os.environ["https_proxy"] = "http://a:b@localhost2:3128/"
 | 
				
			||||||
 | 
					        os.environ["no_proxy"] = "example1.com,example2.com, .websocket.events"
 | 
				
			||||||
 | 
					        self.assertEqual(get_proxy_info("echo.websocket.events", True), (None, 0, None))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        os.environ["http_proxy"] = "http://a:b@localhost:3128/"
 | 
				
			||||||
 | 
					        os.environ["https_proxy"] = "http://a:b@localhost2:3128/"
 | 
				
			||||||
 | 
					        os.environ["no_proxy"] = "127.0.0.0/8, 192.168.0.0/16"
 | 
				
			||||||
 | 
					        self.assertEqual(get_proxy_info("127.0.0.1", False), (None, 0, None))
 | 
				
			||||||
 | 
					        self.assertEqual(get_proxy_info("192.168.1.1", False), (None, 0, None))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					if __name__ == "__main__":
 | 
				
			||||||
 | 
					    unittest.main()
 | 
				
			||||||
							
								
								
									
										497
									
								
								src/libs/websocket/tests/test_websocket.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										497
									
								
								src/libs/websocket/tests/test_websocket.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,497 @@
 | 
				
			|||||||
 | 
					# -*- coding: utf-8 -*-
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					import os
 | 
				
			||||||
 | 
					import os.path
 | 
				
			||||||
 | 
					import socket
 | 
				
			||||||
 | 
					import unittest
 | 
				
			||||||
 | 
					from base64 import decodebytes as base64decode
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import websocket as ws
 | 
				
			||||||
 | 
					from websocket._exceptions import WebSocketBadStatusException, WebSocketAddressException
 | 
				
			||||||
 | 
					from websocket._handshake import _create_sec_websocket_key
 | 
				
			||||||
 | 
					from websocket._handshake import _validate as _validate_header
 | 
				
			||||||
 | 
					from websocket._http import read_headers
 | 
				
			||||||
 | 
					from websocket._utils import validate_utf8
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					"""
 | 
				
			||||||
 | 
					test_websocket.py
 | 
				
			||||||
 | 
					websocket - WebSocket client library for Python
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Copyright 2024 engn33r
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					You may obtain a copy of the License at
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					limitations under the License.
 | 
				
			||||||
 | 
					"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					try:
 | 
				
			||||||
 | 
					    import ssl
 | 
				
			||||||
 | 
					except ImportError:
 | 
				
			||||||
 | 
					    # dummy class of SSLError for ssl none-support environment.
 | 
				
			||||||
 | 
					    class SSLError(Exception):
 | 
				
			||||||
 | 
					        pass
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# Skip test to access the internet unless TEST_WITH_INTERNET == 1
 | 
				
			||||||
 | 
					TEST_WITH_INTERNET = os.environ.get("TEST_WITH_INTERNET", "0") == "1"
 | 
				
			||||||
 | 
					# Skip tests relying on local websockets server unless LOCAL_WS_SERVER_PORT != -1
 | 
				
			||||||
 | 
					LOCAL_WS_SERVER_PORT = os.environ.get("LOCAL_WS_SERVER_PORT", "-1")
 | 
				
			||||||
 | 
					TEST_WITH_LOCAL_SERVER = LOCAL_WS_SERVER_PORT != "-1"
 | 
				
			||||||
 | 
					TRACEABLE = True
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def create_mask_key(_):
 | 
				
			||||||
 | 
					    return "abcd"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class SockMock:
 | 
				
			||||||
 | 
					    def __init__(self):
 | 
				
			||||||
 | 
					        self.data = []
 | 
				
			||||||
 | 
					        self.sent = []
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def add_packet(self, data):
 | 
				
			||||||
 | 
					        self.data.append(data)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def gettimeout(self):
 | 
				
			||||||
 | 
					        return None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def recv(self, bufsize):
 | 
				
			||||||
 | 
					        if self.data:
 | 
				
			||||||
 | 
					            e = self.data.pop(0)
 | 
				
			||||||
 | 
					            if isinstance(e, Exception):
 | 
				
			||||||
 | 
					                raise e
 | 
				
			||||||
 | 
					            if len(e) > bufsize:
 | 
				
			||||||
 | 
					                self.data.insert(0, e[bufsize:])
 | 
				
			||||||
 | 
					            return e[:bufsize]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def send(self, data):
 | 
				
			||||||
 | 
					        self.sent.append(data)
 | 
				
			||||||
 | 
					        return len(data)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def close(self):
 | 
				
			||||||
 | 
					        pass
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class HeaderSockMock(SockMock):
 | 
				
			||||||
 | 
					    def __init__(self, fname):
 | 
				
			||||||
 | 
					        SockMock.__init__(self)
 | 
				
			||||||
 | 
					        path = os.path.join(os.path.dirname(__file__), fname)
 | 
				
			||||||
 | 
					        with open(path, "rb") as f:
 | 
				
			||||||
 | 
					            self.add_packet(f.read())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class WebSocketTest(unittest.TestCase):
 | 
				
			||||||
 | 
					    def setUp(self):
 | 
				
			||||||
 | 
					        ws.enableTrace(TRACEABLE)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def tearDown(self):
 | 
				
			||||||
 | 
					        pass
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_default_timeout(self):
 | 
				
			||||||
 | 
					        self.assertEqual(ws.getdefaulttimeout(), None)
 | 
				
			||||||
 | 
					        ws.setdefaulttimeout(10)
 | 
				
			||||||
 | 
					        self.assertEqual(ws.getdefaulttimeout(), 10)
 | 
				
			||||||
 | 
					        ws.setdefaulttimeout(None)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_ws_key(self):
 | 
				
			||||||
 | 
					        key = _create_sec_websocket_key()
 | 
				
			||||||
 | 
					        self.assertTrue(key != 24)
 | 
				
			||||||
 | 
					        self.assertTrue("¥n" not in key)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_nonce(self):
 | 
				
			||||||
 | 
					        """WebSocket key should be a random 16-byte nonce."""
 | 
				
			||||||
 | 
					        key = _create_sec_websocket_key()
 | 
				
			||||||
 | 
					        nonce = base64decode(key.encode("utf-8"))
 | 
				
			||||||
 | 
					        self.assertEqual(16, len(nonce))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_ws_utils(self):
 | 
				
			||||||
 | 
					        key = "c6b8hTg4EeGb2gQMztV1/g=="
 | 
				
			||||||
 | 
					        required_header = {
 | 
				
			||||||
 | 
					            "upgrade": "websocket",
 | 
				
			||||||
 | 
					            "connection": "upgrade",
 | 
				
			||||||
 | 
					            "sec-websocket-accept": "Kxep+hNu9n51529fGidYu7a3wO0=",
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					        self.assertEqual(_validate_header(required_header, key, None), (True, None))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        header = required_header.copy()
 | 
				
			||||||
 | 
					        header["upgrade"] = "http"
 | 
				
			||||||
 | 
					        self.assertEqual(_validate_header(header, key, None), (False, None))
 | 
				
			||||||
 | 
					        del header["upgrade"]
 | 
				
			||||||
 | 
					        self.assertEqual(_validate_header(header, key, None), (False, None))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        header = required_header.copy()
 | 
				
			||||||
 | 
					        header["connection"] = "something"
 | 
				
			||||||
 | 
					        self.assertEqual(_validate_header(header, key, None), (False, None))
 | 
				
			||||||
 | 
					        del header["connection"]
 | 
				
			||||||
 | 
					        self.assertEqual(_validate_header(header, key, None), (False, None))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        header = required_header.copy()
 | 
				
			||||||
 | 
					        header["sec-websocket-accept"] = "something"
 | 
				
			||||||
 | 
					        self.assertEqual(_validate_header(header, key, None), (False, None))
 | 
				
			||||||
 | 
					        del header["sec-websocket-accept"]
 | 
				
			||||||
 | 
					        self.assertEqual(_validate_header(header, key, None), (False, None))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        header = required_header.copy()
 | 
				
			||||||
 | 
					        header["sec-websocket-protocol"] = "sub1"
 | 
				
			||||||
 | 
					        self.assertEqual(
 | 
				
			||||||
 | 
					            _validate_header(header, key, ["sub1", "sub2"]), (True, "sub1")
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        # This case will print out a logging error using the error() function, but that is expected
 | 
				
			||||||
 | 
					        self.assertEqual(_validate_header(header, key, ["sub2", "sub3"]), (False, None))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        header = required_header.copy()
 | 
				
			||||||
 | 
					        header["sec-websocket-protocol"] = "sUb1"
 | 
				
			||||||
 | 
					        self.assertEqual(
 | 
				
			||||||
 | 
					            _validate_header(header, key, ["Sub1", "suB2"]), (True, "sub1")
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        header = required_header.copy()
 | 
				
			||||||
 | 
					        # This case will print out a logging error using the error() function, but that is expected
 | 
				
			||||||
 | 
					        self.assertEqual(_validate_header(header, key, ["Sub1", "suB2"]), (False, None))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_read_header(self):
 | 
				
			||||||
 | 
					        status, header, _ = read_headers(HeaderSockMock("data/header01.txt"))
 | 
				
			||||||
 | 
					        self.assertEqual(status, 101)
 | 
				
			||||||
 | 
					        self.assertEqual(header["connection"], "Upgrade")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        status, header, _ = read_headers(HeaderSockMock("data/header03.txt"))
 | 
				
			||||||
 | 
					        self.assertEqual(status, 101)
 | 
				
			||||||
 | 
					        self.assertEqual(header["connection"], "Upgrade, Keep-Alive")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        HeaderSockMock("data/header02.txt")
 | 
				
			||||||
 | 
					        self.assertRaises(
 | 
				
			||||||
 | 
					            ws.WebSocketException, read_headers, HeaderSockMock("data/header02.txt")
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_send(self):
 | 
				
			||||||
 | 
					        # TODO: add longer frame data
 | 
				
			||||||
 | 
					        sock = ws.WebSocket()
 | 
				
			||||||
 | 
					        sock.set_mask_key(create_mask_key)
 | 
				
			||||||
 | 
					        s = sock.sock = HeaderSockMock("data/header01.txt")
 | 
				
			||||||
 | 
					        sock.send("Hello")
 | 
				
			||||||
 | 
					        self.assertEqual(s.sent[0], b"\x81\x85abcd)\x07\x0f\x08\x0e")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        sock.send("こんにちは")
 | 
				
			||||||
 | 
					        self.assertEqual(
 | 
				
			||||||
 | 
					            s.sent[1],
 | 
				
			||||||
 | 
					            b"\x81\x8fabcd\x82\xe3\xf0\x87\xe3\xf1\x80\xe5\xca\x81\xe2\xc5\x82\xe3\xcc",
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        #        sock.send("x" * 5000)
 | 
				
			||||||
 | 
					        #        self.assertEqual(s.sent[1], b'\x81\x8fabcd\x82\xe3\xf0\x87\xe3\xf1\x80\xe5\xca\x81\xe2\xc5\x82\xe3\xcc")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.assertEqual(sock.send_binary(b"1111111111101"), 19)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_recv(self):
 | 
				
			||||||
 | 
					        # TODO: add longer frame data
 | 
				
			||||||
 | 
					        sock = ws.WebSocket()
 | 
				
			||||||
 | 
					        s = sock.sock = SockMock()
 | 
				
			||||||
 | 
					        something = (
 | 
				
			||||||
 | 
					            b"\x81\x8fabcd\x82\xe3\xf0\x87\xe3\xf1\x80\xe5\xca\x81\xe2\xc5\x82\xe3\xcc"
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        s.add_packet(something)
 | 
				
			||||||
 | 
					        data = sock.recv()
 | 
				
			||||||
 | 
					        self.assertEqual(data, "こんにちは")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        s.add_packet(b"\x81\x85abcd)\x07\x0f\x08\x0e")
 | 
				
			||||||
 | 
					        data = sock.recv()
 | 
				
			||||||
 | 
					        self.assertEqual(data, "Hello")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
 | 
				
			||||||
 | 
					    def test_iter(self):
 | 
				
			||||||
 | 
					        count = 2
 | 
				
			||||||
 | 
					        s = ws.create_connection("wss://api.bitfinex.com/ws/2")
 | 
				
			||||||
 | 
					        s.send('{"event": "subscribe", "channel": "ticker"}')
 | 
				
			||||||
 | 
					        for _ in s:
 | 
				
			||||||
 | 
					            count -= 1
 | 
				
			||||||
 | 
					            if count == 0:
 | 
				
			||||||
 | 
					                break
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
 | 
				
			||||||
 | 
					    def test_next(self):
 | 
				
			||||||
 | 
					        sock = ws.create_connection("wss://api.bitfinex.com/ws/2")
 | 
				
			||||||
 | 
					        self.assertEqual(str, type(next(sock)))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_internal_recv_strict(self):
 | 
				
			||||||
 | 
					        sock = ws.WebSocket()
 | 
				
			||||||
 | 
					        s = sock.sock = SockMock()
 | 
				
			||||||
 | 
					        s.add_packet(b"foo")
 | 
				
			||||||
 | 
					        s.add_packet(socket.timeout())
 | 
				
			||||||
 | 
					        s.add_packet(b"bar")
 | 
				
			||||||
 | 
					        # s.add_packet(SSLError("The read operation timed out"))
 | 
				
			||||||
 | 
					        s.add_packet(b"baz")
 | 
				
			||||||
 | 
					        with self.assertRaises(ws.WebSocketTimeoutException):
 | 
				
			||||||
 | 
					            sock.frame_buffer.recv_strict(9)
 | 
				
			||||||
 | 
					        #     with self.assertRaises(SSLError):
 | 
				
			||||||
 | 
					        #         data = sock._recv_strict(9)
 | 
				
			||||||
 | 
					        data = sock.frame_buffer.recv_strict(9)
 | 
				
			||||||
 | 
					        self.assertEqual(data, b"foobarbaz")
 | 
				
			||||||
 | 
					        with self.assertRaises(ws.WebSocketConnectionClosedException):
 | 
				
			||||||
 | 
					            sock.frame_buffer.recv_strict(1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_recv_timeout(self):
 | 
				
			||||||
 | 
					        sock = ws.WebSocket()
 | 
				
			||||||
 | 
					        s = sock.sock = SockMock()
 | 
				
			||||||
 | 
					        s.add_packet(b"\x81")
 | 
				
			||||||
 | 
					        s.add_packet(socket.timeout())
 | 
				
			||||||
 | 
					        s.add_packet(b"\x8dabcd\x29\x07\x0f\x08\x0e")
 | 
				
			||||||
 | 
					        s.add_packet(socket.timeout())
 | 
				
			||||||
 | 
					        s.add_packet(b"\x4e\x43\x33\x0e\x10\x0f\x00\x40")
 | 
				
			||||||
 | 
					        with self.assertRaises(ws.WebSocketTimeoutException):
 | 
				
			||||||
 | 
					            sock.recv()
 | 
				
			||||||
 | 
					        with self.assertRaises(ws.WebSocketTimeoutException):
 | 
				
			||||||
 | 
					            sock.recv()
 | 
				
			||||||
 | 
					        data = sock.recv()
 | 
				
			||||||
 | 
					        self.assertEqual(data, "Hello, World!")
 | 
				
			||||||
 | 
					        with self.assertRaises(ws.WebSocketConnectionClosedException):
 | 
				
			||||||
 | 
					            sock.recv()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_recv_with_simple_fragmentation(self):
 | 
				
			||||||
 | 
					        sock = ws.WebSocket()
 | 
				
			||||||
 | 
					        s = sock.sock = SockMock()
 | 
				
			||||||
 | 
					        # OPCODE=TEXT, FIN=0, MSG="Brevity is "
 | 
				
			||||||
 | 
					        s.add_packet(b"\x01\x8babcd#\x10\x06\x12\x08\x16\x1aD\x08\x11C")
 | 
				
			||||||
 | 
					        # OPCODE=CONT, FIN=1, MSG="the soul of wit"
 | 
				
			||||||
 | 
					        s.add_packet(b"\x80\x8fabcd\x15\n\x06D\x12\r\x16\x08A\r\x05D\x16\x0b\x17")
 | 
				
			||||||
 | 
					        data = sock.recv()
 | 
				
			||||||
 | 
					        self.assertEqual(data, "Brevity is the soul of wit")
 | 
				
			||||||
 | 
					        with self.assertRaises(ws.WebSocketConnectionClosedException):
 | 
				
			||||||
 | 
					            sock.recv()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_recv_with_fire_event_of_fragmentation(self):
 | 
				
			||||||
 | 
					        sock = ws.WebSocket(fire_cont_frame=True)
 | 
				
			||||||
 | 
					        s = sock.sock = SockMock()
 | 
				
			||||||
 | 
					        # OPCODE=TEXT, FIN=0, MSG="Brevity is "
 | 
				
			||||||
 | 
					        s.add_packet(b"\x01\x8babcd#\x10\x06\x12\x08\x16\x1aD\x08\x11C")
 | 
				
			||||||
 | 
					        # OPCODE=CONT, FIN=0, MSG="Brevity is "
 | 
				
			||||||
 | 
					        s.add_packet(b"\x00\x8babcd#\x10\x06\x12\x08\x16\x1aD\x08\x11C")
 | 
				
			||||||
 | 
					        # OPCODE=CONT, FIN=1, MSG="the soul of wit"
 | 
				
			||||||
 | 
					        s.add_packet(b"\x80\x8fabcd\x15\n\x06D\x12\r\x16\x08A\r\x05D\x16\x0b\x17")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        _, data = sock.recv_data()
 | 
				
			||||||
 | 
					        self.assertEqual(data, b"Brevity is ")
 | 
				
			||||||
 | 
					        _, data = sock.recv_data()
 | 
				
			||||||
 | 
					        self.assertEqual(data, b"Brevity is ")
 | 
				
			||||||
 | 
					        _, data = sock.recv_data()
 | 
				
			||||||
 | 
					        self.assertEqual(data, b"the soul of wit")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # OPCODE=CONT, FIN=0, MSG="Brevity is "
 | 
				
			||||||
 | 
					        s.add_packet(b"\x80\x8babcd#\x10\x06\x12\x08\x16\x1aD\x08\x11C")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        with self.assertRaises(ws.WebSocketException):
 | 
				
			||||||
 | 
					            sock.recv_data()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        with self.assertRaises(ws.WebSocketConnectionClosedException):
 | 
				
			||||||
 | 
					            sock.recv()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_close(self):
 | 
				
			||||||
 | 
					        sock = ws.WebSocket()
 | 
				
			||||||
 | 
					        sock.connected = True
 | 
				
			||||||
 | 
					        sock.close
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        sock = ws.WebSocket()
 | 
				
			||||||
 | 
					        s = sock.sock = SockMock()
 | 
				
			||||||
 | 
					        sock.connected = True
 | 
				
			||||||
 | 
					        s.add_packet(b"\x88\x80\x17\x98p\x84")
 | 
				
			||||||
 | 
					        sock.recv()
 | 
				
			||||||
 | 
					        self.assertEqual(sock.connected, False)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_recv_cont_fragmentation(self):
 | 
				
			||||||
 | 
					        sock = ws.WebSocket()
 | 
				
			||||||
 | 
					        s = sock.sock = SockMock()
 | 
				
			||||||
 | 
					        # OPCODE=CONT, FIN=1, MSG="the soul of wit"
 | 
				
			||||||
 | 
					        s.add_packet(b"\x80\x8fabcd\x15\n\x06D\x12\r\x16\x08A\r\x05D\x16\x0b\x17")
 | 
				
			||||||
 | 
					        self.assertRaises(ws.WebSocketException, sock.recv)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_recv_with_prolonged_fragmentation(self):
 | 
				
			||||||
 | 
					        sock = ws.WebSocket()
 | 
				
			||||||
 | 
					        s = sock.sock = SockMock()
 | 
				
			||||||
 | 
					        # OPCODE=TEXT, FIN=0, MSG="Once more unto the breach, "
 | 
				
			||||||
 | 
					        s.add_packet(
 | 
				
			||||||
 | 
					            b"\x01\x9babcd.\x0c\x00\x01A\x0f\x0c\x16\x04B\x16\n\x15\rC\x10\t\x07C\x06\x13\x07\x02\x07\tNC"
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        # OPCODE=CONT, FIN=0, MSG="dear friends, "
 | 
				
			||||||
 | 
					        s.add_packet(b"\x00\x8eabcd\x05\x07\x02\x16A\x04\x11\r\x04\x0c\x07\x17MB")
 | 
				
			||||||
 | 
					        # OPCODE=CONT, FIN=1, MSG="once more"
 | 
				
			||||||
 | 
					        s.add_packet(b"\x80\x89abcd\x0e\x0c\x00\x01A\x0f\x0c\x16\x04")
 | 
				
			||||||
 | 
					        data = sock.recv()
 | 
				
			||||||
 | 
					        self.assertEqual(data, "Once more unto the breach, dear friends, once more")
 | 
				
			||||||
 | 
					        with self.assertRaises(ws.WebSocketConnectionClosedException):
 | 
				
			||||||
 | 
					            sock.recv()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_recv_with_fragmentation_and_control_frame(self):
 | 
				
			||||||
 | 
					        sock = ws.WebSocket()
 | 
				
			||||||
 | 
					        sock.set_mask_key(create_mask_key)
 | 
				
			||||||
 | 
					        s = sock.sock = SockMock()
 | 
				
			||||||
 | 
					        # OPCODE=TEXT, FIN=0, MSG="Too much "
 | 
				
			||||||
 | 
					        s.add_packet(b"\x01\x89abcd5\r\x0cD\x0c\x17\x00\x0cA")
 | 
				
			||||||
 | 
					        # OPCODE=PING, FIN=1, MSG="Please PONG this"
 | 
				
			||||||
 | 
					        s.add_packet(b"\x89\x90abcd1\x0e\x06\x05\x12\x07C4.,$D\x15\n\n\x17")
 | 
				
			||||||
 | 
					        # OPCODE=CONT, FIN=1, MSG="of a good thing"
 | 
				
			||||||
 | 
					        s.add_packet(b"\x80\x8fabcd\x0e\x04C\x05A\x05\x0c\x0b\x05B\x17\x0c\x08\x0c\x04")
 | 
				
			||||||
 | 
					        data = sock.recv()
 | 
				
			||||||
 | 
					        self.assertEqual(data, "Too much of a good thing")
 | 
				
			||||||
 | 
					        with self.assertRaises(ws.WebSocketConnectionClosedException):
 | 
				
			||||||
 | 
					            sock.recv()
 | 
				
			||||||
 | 
					        self.assertEqual(
 | 
				
			||||||
 | 
					            s.sent[0], b"\x8a\x90abcd1\x0e\x06\x05\x12\x07C4.,$D\x15\n\n\x17"
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @unittest.skipUnless(
 | 
				
			||||||
 | 
					        TEST_WITH_LOCAL_SERVER, "Tests using local websocket server are disabled"
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    def test_websocket(self):
 | 
				
			||||||
 | 
					        s = ws.create_connection(f"ws://127.0.0.1:{LOCAL_WS_SERVER_PORT}")
 | 
				
			||||||
 | 
					        self.assertNotEqual(s, None)
 | 
				
			||||||
 | 
					        s.send("Hello, World")
 | 
				
			||||||
 | 
					        result = s.next()
 | 
				
			||||||
 | 
					        s.fileno()
 | 
				
			||||||
 | 
					        self.assertEqual(result, "Hello, World")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        s.send("こにゃにゃちは、世界")
 | 
				
			||||||
 | 
					        result = s.recv()
 | 
				
			||||||
 | 
					        self.assertEqual(result, "こにゃにゃちは、世界")
 | 
				
			||||||
 | 
					        self.assertRaises(ValueError, s.send_close, -1, "")
 | 
				
			||||||
 | 
					        s.close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @unittest.skipUnless(
 | 
				
			||||||
 | 
					        TEST_WITH_LOCAL_SERVER, "Tests using local websocket server are disabled"
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    def test_ping_pong(self):
 | 
				
			||||||
 | 
					        s = ws.create_connection(f"ws://127.0.0.1:{LOCAL_WS_SERVER_PORT}")
 | 
				
			||||||
 | 
					        self.assertNotEqual(s, None)
 | 
				
			||||||
 | 
					        s.ping("Hello")
 | 
				
			||||||
 | 
					        s.pong("Hi")
 | 
				
			||||||
 | 
					        s.close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
 | 
				
			||||||
 | 
					    def test_support_redirect(self):
 | 
				
			||||||
 | 
					        s = ws.WebSocket()
 | 
				
			||||||
 | 
					        self.assertRaises(WebSocketBadStatusException, s.connect, "ws://google.com/")
 | 
				
			||||||
 | 
					        # Need to find a URL that has a redirect code leading to a websocket
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
 | 
				
			||||||
 | 
					    def test_secure_websocket(self):
 | 
				
			||||||
 | 
					        s = ws.create_connection("wss://api.bitfinex.com/ws/2")
 | 
				
			||||||
 | 
					        self.assertNotEqual(s, None)
 | 
				
			||||||
 | 
					        self.assertTrue(isinstance(s.sock, ssl.SSLSocket))
 | 
				
			||||||
 | 
					        self.assertEqual(s.getstatus(), 101)
 | 
				
			||||||
 | 
					        self.assertNotEqual(s.getheaders(), None)
 | 
				
			||||||
 | 
					        s.settimeout(10)
 | 
				
			||||||
 | 
					        self.assertEqual(s.gettimeout(), 10)
 | 
				
			||||||
 | 
					        self.assertEqual(s.getsubprotocol(), None)
 | 
				
			||||||
 | 
					        s.abort()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @unittest.skipUnless(
 | 
				
			||||||
 | 
					        TEST_WITH_LOCAL_SERVER, "Tests using local websocket server are disabled"
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    def test_websocket_with_custom_header(self):
 | 
				
			||||||
 | 
					        s = ws.create_connection(
 | 
				
			||||||
 | 
					            f"ws://127.0.0.1:{LOCAL_WS_SERVER_PORT}",
 | 
				
			||||||
 | 
					            headers={"User-Agent": "PythonWebsocketClient"},
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.assertNotEqual(s, None)
 | 
				
			||||||
 | 
					        self.assertEqual(s.getsubprotocol(), None)
 | 
				
			||||||
 | 
					        s.send("Hello, World")
 | 
				
			||||||
 | 
					        result = s.recv()
 | 
				
			||||||
 | 
					        self.assertEqual(result, "Hello, World")
 | 
				
			||||||
 | 
					        self.assertRaises(ValueError, s.close, -1, "")
 | 
				
			||||||
 | 
					        s.close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @unittest.skipUnless(
 | 
				
			||||||
 | 
					        TEST_WITH_LOCAL_SERVER, "Tests using local websocket server are disabled"
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    def test_after_close(self):
 | 
				
			||||||
 | 
					        s = ws.create_connection(f"ws://127.0.0.1:{LOCAL_WS_SERVER_PORT}")
 | 
				
			||||||
 | 
					        self.assertNotEqual(s, None)
 | 
				
			||||||
 | 
					        s.close()
 | 
				
			||||||
 | 
					        self.assertRaises(ws.WebSocketConnectionClosedException, s.send, "Hello")
 | 
				
			||||||
 | 
					        self.assertRaises(ws.WebSocketConnectionClosedException, s.recv)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class SockOptTest(unittest.TestCase):
 | 
				
			||||||
 | 
					    @unittest.skipUnless(
 | 
				
			||||||
 | 
					        TEST_WITH_LOCAL_SERVER, "Tests using local websocket server are disabled"
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    def test_sockopt(self):
 | 
				
			||||||
 | 
					        sockopt = ((socket.IPPROTO_TCP, socket.TCP_NODELAY, 1),)
 | 
				
			||||||
 | 
					        s = ws.create_connection(
 | 
				
			||||||
 | 
					            f"ws://127.0.0.1:{LOCAL_WS_SERVER_PORT}", sockopt=sockopt
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.assertNotEqual(
 | 
				
			||||||
 | 
					            s.sock.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY), 0
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        s.close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class UtilsTest(unittest.TestCase):
 | 
				
			||||||
 | 
					    def test_utf8_validator(self):
 | 
				
			||||||
 | 
					        state = validate_utf8(b"\xf0\x90\x80\x80")
 | 
				
			||||||
 | 
					        self.assertEqual(state, True)
 | 
				
			||||||
 | 
					        state = validate_utf8(
 | 
				
			||||||
 | 
					            b"\xce\xba\xe1\xbd\xb9\xcf\x83\xce\xbc\xce\xb5\xed\xa0\x80edited"
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.assertEqual(state, False)
 | 
				
			||||||
 | 
					        state = validate_utf8(b"")
 | 
				
			||||||
 | 
					        self.assertEqual(state, True)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class HandshakeTest(unittest.TestCase):
 | 
				
			||||||
 | 
					    @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
 | 
				
			||||||
 | 
					    def test_http_ssl(self):
 | 
				
			||||||
 | 
					        websock1 = ws.WebSocket(
 | 
				
			||||||
 | 
					            sslopt={"cert_chain": ssl.get_default_verify_paths().capath},
 | 
				
			||||||
 | 
					            enable_multithread=False,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.assertRaises(ValueError, websock1.connect, "wss://api.bitfinex.com/ws/2")
 | 
				
			||||||
 | 
					        websock2 = ws.WebSocket(sslopt={"certfile": "myNonexistentCertFile"})
 | 
				
			||||||
 | 
					        self.assertRaises(
 | 
				
			||||||
 | 
					            FileNotFoundError, websock2.connect, "wss://api.bitfinex.com/ws/2"
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
 | 
				
			||||||
 | 
					    def test_manual_headers(self):
 | 
				
			||||||
 | 
					        websock3 = ws.WebSocket(
 | 
				
			||||||
 | 
					            sslopt={
 | 
				
			||||||
 | 
					                "ca_certs": ssl.get_default_verify_paths().cafile,
 | 
				
			||||||
 | 
					                "ca_cert_path": ssl.get_default_verify_paths().capath,
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.assertRaises(
 | 
				
			||||||
 | 
					            WebSocketBadStatusException,
 | 
				
			||||||
 | 
					            websock3.connect,
 | 
				
			||||||
 | 
					            "wss://api.bitfinex.com/ws/2",
 | 
				
			||||||
 | 
					            cookie="chocolate",
 | 
				
			||||||
 | 
					            origin="testing_websockets.com",
 | 
				
			||||||
 | 
					            host="echo.websocket.events/websocket-client-test",
 | 
				
			||||||
 | 
					            subprotocols=["testproto"],
 | 
				
			||||||
 | 
					            connection="Upgrade",
 | 
				
			||||||
 | 
					            header={
 | 
				
			||||||
 | 
					                "CustomHeader1": "123",
 | 
				
			||||||
 | 
					                "Cookie": "TestValue",
 | 
				
			||||||
 | 
					                "Sec-WebSocket-Key": "k9kFAUWNAMmf5OEMfTlOEA==",
 | 
				
			||||||
 | 
					                "Sec-WebSocket-Protocol": "newprotocol",
 | 
				
			||||||
 | 
					            },
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_ipv6(self):
 | 
				
			||||||
 | 
					        websock2 = ws.WebSocket()
 | 
				
			||||||
 | 
					        self.assertRaises(ValueError, websock2.connect, "2001:4860:4860::8888")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_bad_urls(self):
 | 
				
			||||||
 | 
					        websock3 = ws.WebSocket()
 | 
				
			||||||
 | 
					        self.assertRaises(ValueError, websock3.connect, "ws//example.com")
 | 
				
			||||||
 | 
					        self.assertRaises(WebSocketAddressException, websock3.connect, "ws://example")
 | 
				
			||||||
 | 
					        self.assertRaises(ValueError, websock3.connect, "example.com")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					if __name__ == "__main__":
 | 
				
			||||||
 | 
					    unittest.main()
 | 
				
			||||||
							
								
								
									
										54
									
								
								src/libs/websocket_client.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										54
									
								
								src/libs/websocket_client.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,54 @@
 | 
				
			|||||||
 | 
					# Python imports
 | 
				
			||||||
 | 
					# import rel
 | 
				
			||||||
 | 
					import json
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# Lib imports
 | 
				
			||||||
 | 
					import gi
 | 
				
			||||||
 | 
					from gi.repository import GLib
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# Application imports
 | 
				
			||||||
 | 
					from . import websocket 
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class WebsocketClient:
 | 
				
			||||||
 | 
					    def __init__(self):
 | 
				
			||||||
 | 
					        self.ws = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def send(self, message: str):
 | 
				
			||||||
 | 
					        self.ws.send(message)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def set_callback(self, callback: object):
 | 
				
			||||||
 | 
					        self.respond = callback
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def on_message(self, ws, message: dict):
 | 
				
			||||||
 | 
					        GLib.idle_add(self.respond, message)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def on_error(self, ws, error: str):
 | 
				
			||||||
 | 
					        logger.debug(f"WS Error:  {error}")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def on_close(self, ws, close_status_code: int, close_msg: str):
 | 
				
			||||||
 | 
					        logger.debug("WS Closed...")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def on_open(self, ws):
 | 
				
			||||||
 | 
					        logger.debug("WS opened connection...")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def close_client(self):
 | 
				
			||||||
 | 
					        self.ws.close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @daemon_threaded
 | 
				
			||||||
 | 
					    def start_client(self):
 | 
				
			||||||
 | 
					        # websocket.enableTrace(True)
 | 
				
			||||||
 | 
					        self.ws = websocket.WebSocketApp("ws://localhost:4114",
 | 
				
			||||||
 | 
					                                  on_open = self.on_open,
 | 
				
			||||||
 | 
					                                  on_message = self.on_message,
 | 
				
			||||||
 | 
					                                  on_error = self.on_error,
 | 
				
			||||||
 | 
					                                  on_close = self.on_close)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Set dispatcher to automatic reconnection, 5 second reconnect delay
 | 
				
			||||||
 | 
					        # self.ws.run_forever(dispatcher = rel, reconnect = 5)
 | 
				
			||||||
 | 
					        # rel.signal(2, rel.abort)  # Keyboard Interrupt
 | 
				
			||||||
 | 
					        # rel.dispatch()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.ws.run_forever(reconnect = 0.5)
 | 
				
			||||||
@@ -1,199 +0,0 @@
 | 
				
			|||||||
from __future__ import annotations
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
import typing
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
from .imports import lazy_import
 | 
					 | 
				
			||||||
from .version import version as __version__  # noqa: F401
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
__all__ = [
 | 
					 | 
				
			||||||
    # .client
 | 
					 | 
				
			||||||
    "ClientProtocol",
 | 
					 | 
				
			||||||
    # .datastructures
 | 
					 | 
				
			||||||
    "Headers",
 | 
					 | 
				
			||||||
    "HeadersLike",
 | 
					 | 
				
			||||||
    "MultipleValuesError",
 | 
					 | 
				
			||||||
    # .exceptions
 | 
					 | 
				
			||||||
    "ConcurrencyError",
 | 
					 | 
				
			||||||
    "ConnectionClosed",
 | 
					 | 
				
			||||||
    "ConnectionClosedError",
 | 
					 | 
				
			||||||
    "ConnectionClosedOK",
 | 
					 | 
				
			||||||
    "DuplicateParameter",
 | 
					 | 
				
			||||||
    "InvalidHandshake",
 | 
					 | 
				
			||||||
    "InvalidHeader",
 | 
					 | 
				
			||||||
    "InvalidHeaderFormat",
 | 
					 | 
				
			||||||
    "InvalidHeaderValue",
 | 
					 | 
				
			||||||
    "InvalidOrigin",
 | 
					 | 
				
			||||||
    "InvalidParameterName",
 | 
					 | 
				
			||||||
    "InvalidParameterValue",
 | 
					 | 
				
			||||||
    "InvalidState",
 | 
					 | 
				
			||||||
    "InvalidStatus",
 | 
					 | 
				
			||||||
    "InvalidUpgrade",
 | 
					 | 
				
			||||||
    "InvalidURI",
 | 
					 | 
				
			||||||
    "NegotiationError",
 | 
					 | 
				
			||||||
    "PayloadTooBig",
 | 
					 | 
				
			||||||
    "ProtocolError",
 | 
					 | 
				
			||||||
    "SecurityError",
 | 
					 | 
				
			||||||
    "WebSocketException",
 | 
					 | 
				
			||||||
    "WebSocketProtocolError",
 | 
					 | 
				
			||||||
    # .legacy.auth
 | 
					 | 
				
			||||||
    "BasicAuthWebSocketServerProtocol",
 | 
					 | 
				
			||||||
    "basic_auth_protocol_factory",
 | 
					 | 
				
			||||||
    # .legacy.client
 | 
					 | 
				
			||||||
    "WebSocketClientProtocol",
 | 
					 | 
				
			||||||
    "connect",
 | 
					 | 
				
			||||||
    "unix_connect",
 | 
					 | 
				
			||||||
    # .legacy.exceptions
 | 
					 | 
				
			||||||
    "AbortHandshake",
 | 
					 | 
				
			||||||
    "InvalidMessage",
 | 
					 | 
				
			||||||
    "InvalidStatusCode",
 | 
					 | 
				
			||||||
    "RedirectHandshake",
 | 
					 | 
				
			||||||
    # .legacy.protocol
 | 
					 | 
				
			||||||
    "WebSocketCommonProtocol",
 | 
					 | 
				
			||||||
    # .legacy.server
 | 
					 | 
				
			||||||
    "WebSocketServer",
 | 
					 | 
				
			||||||
    "WebSocketServerProtocol",
 | 
					 | 
				
			||||||
    "broadcast",
 | 
					 | 
				
			||||||
    "serve",
 | 
					 | 
				
			||||||
    "unix_serve",
 | 
					 | 
				
			||||||
    # .server
 | 
					 | 
				
			||||||
    "ServerProtocol",
 | 
					 | 
				
			||||||
    # .typing
 | 
					 | 
				
			||||||
    "Data",
 | 
					 | 
				
			||||||
    "ExtensionName",
 | 
					 | 
				
			||||||
    "ExtensionParameter",
 | 
					 | 
				
			||||||
    "LoggerLike",
 | 
					 | 
				
			||||||
    "StatusLike",
 | 
					 | 
				
			||||||
    "Origin",
 | 
					 | 
				
			||||||
    "Subprotocol",
 | 
					 | 
				
			||||||
]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
# When type checking, import non-deprecated aliases eagerly. Else, import on demand.
 | 
					 | 
				
			||||||
if typing.TYPE_CHECKING:
 | 
					 | 
				
			||||||
    from .client import ClientProtocol
 | 
					 | 
				
			||||||
    from .datastructures import Headers, HeadersLike, MultipleValuesError
 | 
					 | 
				
			||||||
    from .exceptions import (
 | 
					 | 
				
			||||||
        ConcurrencyError,
 | 
					 | 
				
			||||||
        ConnectionClosed,
 | 
					 | 
				
			||||||
        ConnectionClosedError,
 | 
					 | 
				
			||||||
        ConnectionClosedOK,
 | 
					 | 
				
			||||||
        DuplicateParameter,
 | 
					 | 
				
			||||||
        InvalidHandshake,
 | 
					 | 
				
			||||||
        InvalidHeader,
 | 
					 | 
				
			||||||
        InvalidHeaderFormat,
 | 
					 | 
				
			||||||
        InvalidHeaderValue,
 | 
					 | 
				
			||||||
        InvalidOrigin,
 | 
					 | 
				
			||||||
        InvalidParameterName,
 | 
					 | 
				
			||||||
        InvalidParameterValue,
 | 
					 | 
				
			||||||
        InvalidState,
 | 
					 | 
				
			||||||
        InvalidStatus,
 | 
					 | 
				
			||||||
        InvalidUpgrade,
 | 
					 | 
				
			||||||
        InvalidURI,
 | 
					 | 
				
			||||||
        NegotiationError,
 | 
					 | 
				
			||||||
        PayloadTooBig,
 | 
					 | 
				
			||||||
        ProtocolError,
 | 
					 | 
				
			||||||
        SecurityError,
 | 
					 | 
				
			||||||
        WebSocketException,
 | 
					 | 
				
			||||||
        WebSocketProtocolError,
 | 
					 | 
				
			||||||
    )
 | 
					 | 
				
			||||||
    from .legacy.auth import (
 | 
					 | 
				
			||||||
        BasicAuthWebSocketServerProtocol,
 | 
					 | 
				
			||||||
        basic_auth_protocol_factory,
 | 
					 | 
				
			||||||
    )
 | 
					 | 
				
			||||||
    from .legacy.client import WebSocketClientProtocol, connect, unix_connect
 | 
					 | 
				
			||||||
    from .legacy.exceptions import (
 | 
					 | 
				
			||||||
        AbortHandshake,
 | 
					 | 
				
			||||||
        InvalidMessage,
 | 
					 | 
				
			||||||
        InvalidStatusCode,
 | 
					 | 
				
			||||||
        RedirectHandshake,
 | 
					 | 
				
			||||||
    )
 | 
					 | 
				
			||||||
    from .legacy.protocol import WebSocketCommonProtocol
 | 
					 | 
				
			||||||
    from .legacy.server import (
 | 
					 | 
				
			||||||
        WebSocketServer,
 | 
					 | 
				
			||||||
        WebSocketServerProtocol,
 | 
					 | 
				
			||||||
        broadcast,
 | 
					 | 
				
			||||||
        serve,
 | 
					 | 
				
			||||||
        unix_serve,
 | 
					 | 
				
			||||||
    )
 | 
					 | 
				
			||||||
    from .server import ServerProtocol
 | 
					 | 
				
			||||||
    from .typing import (
 | 
					 | 
				
			||||||
        Data,
 | 
					 | 
				
			||||||
        ExtensionName,
 | 
					 | 
				
			||||||
        ExtensionParameter,
 | 
					 | 
				
			||||||
        LoggerLike,
 | 
					 | 
				
			||||||
        Origin,
 | 
					 | 
				
			||||||
        StatusLike,
 | 
					 | 
				
			||||||
        Subprotocol,
 | 
					 | 
				
			||||||
    )
 | 
					 | 
				
			||||||
else:
 | 
					 | 
				
			||||||
    lazy_import(
 | 
					 | 
				
			||||||
        globals(),
 | 
					 | 
				
			||||||
        aliases={
 | 
					 | 
				
			||||||
            # .client
 | 
					 | 
				
			||||||
            "ClientProtocol": ".client",
 | 
					 | 
				
			||||||
            # .datastructures
 | 
					 | 
				
			||||||
            "Headers": ".datastructures",
 | 
					 | 
				
			||||||
            "HeadersLike": ".datastructures",
 | 
					 | 
				
			||||||
            "MultipleValuesError": ".datastructures",
 | 
					 | 
				
			||||||
            # .exceptions
 | 
					 | 
				
			||||||
            "ConcurrencyError": ".exceptions",
 | 
					 | 
				
			||||||
            "ConnectionClosed": ".exceptions",
 | 
					 | 
				
			||||||
            "ConnectionClosedError": ".exceptions",
 | 
					 | 
				
			||||||
            "ConnectionClosedOK": ".exceptions",
 | 
					 | 
				
			||||||
            "DuplicateParameter": ".exceptions",
 | 
					 | 
				
			||||||
            "InvalidHandshake": ".exceptions",
 | 
					 | 
				
			||||||
            "InvalidHeader": ".exceptions",
 | 
					 | 
				
			||||||
            "InvalidHeaderFormat": ".exceptions",
 | 
					 | 
				
			||||||
            "InvalidHeaderValue": ".exceptions",
 | 
					 | 
				
			||||||
            "InvalidOrigin": ".exceptions",
 | 
					 | 
				
			||||||
            "InvalidParameterName": ".exceptions",
 | 
					 | 
				
			||||||
            "InvalidParameterValue": ".exceptions",
 | 
					 | 
				
			||||||
            "InvalidState": ".exceptions",
 | 
					 | 
				
			||||||
            "InvalidStatus": ".exceptions",
 | 
					 | 
				
			||||||
            "InvalidUpgrade": ".exceptions",
 | 
					 | 
				
			||||||
            "InvalidURI": ".exceptions",
 | 
					 | 
				
			||||||
            "NegotiationError": ".exceptions",
 | 
					 | 
				
			||||||
            "PayloadTooBig": ".exceptions",
 | 
					 | 
				
			||||||
            "ProtocolError": ".exceptions",
 | 
					 | 
				
			||||||
            "SecurityError": ".exceptions",
 | 
					 | 
				
			||||||
            "WebSocketException": ".exceptions",
 | 
					 | 
				
			||||||
            "WebSocketProtocolError": ".exceptions",
 | 
					 | 
				
			||||||
            # .legacy.auth
 | 
					 | 
				
			||||||
            "BasicAuthWebSocketServerProtocol": ".legacy.auth",
 | 
					 | 
				
			||||||
            "basic_auth_protocol_factory": ".legacy.auth",
 | 
					 | 
				
			||||||
            # .legacy.client
 | 
					 | 
				
			||||||
            "WebSocketClientProtocol": ".legacy.client",
 | 
					 | 
				
			||||||
            "connect": ".legacy.client",
 | 
					 | 
				
			||||||
            "unix_connect": ".legacy.client",
 | 
					 | 
				
			||||||
            # .legacy.exceptions
 | 
					 | 
				
			||||||
            "AbortHandshake": ".legacy.exceptions",
 | 
					 | 
				
			||||||
            "InvalidMessage": ".legacy.exceptions",
 | 
					 | 
				
			||||||
            "InvalidStatusCode": ".legacy.exceptions",
 | 
					 | 
				
			||||||
            "RedirectHandshake": ".legacy.exceptions",
 | 
					 | 
				
			||||||
            # .legacy.protocol
 | 
					 | 
				
			||||||
            "WebSocketCommonProtocol": ".legacy.protocol",
 | 
					 | 
				
			||||||
            # .legacy.server
 | 
					 | 
				
			||||||
            "WebSocketServer": ".legacy.server",
 | 
					 | 
				
			||||||
            "WebSocketServerProtocol": ".legacy.server",
 | 
					 | 
				
			||||||
            "broadcast": ".legacy.server",
 | 
					 | 
				
			||||||
            "serve": ".legacy.server",
 | 
					 | 
				
			||||||
            "unix_serve": ".legacy.server",
 | 
					 | 
				
			||||||
            # .server
 | 
					 | 
				
			||||||
            "ServerProtocol": ".server",
 | 
					 | 
				
			||||||
            # .typing
 | 
					 | 
				
			||||||
            "Data": ".typing",
 | 
					 | 
				
			||||||
            "ExtensionName": ".typing",
 | 
					 | 
				
			||||||
            "ExtensionParameter": ".typing",
 | 
					 | 
				
			||||||
            "LoggerLike": ".typing",
 | 
					 | 
				
			||||||
            "Origin": ".typing",
 | 
					 | 
				
			||||||
            "StatusLike": ".typing",
 | 
					 | 
				
			||||||
            "Subprotocol": ".typing",
 | 
					 | 
				
			||||||
        },
 | 
					 | 
				
			||||||
        deprecated_aliases={
 | 
					 | 
				
			||||||
            # deprecated in 9.0 - 2021-09-01
 | 
					 | 
				
			||||||
            "framing": ".legacy",
 | 
					 | 
				
			||||||
            "handshake": ".legacy",
 | 
					 | 
				
			||||||
            "parse_uri": ".uri",
 | 
					 | 
				
			||||||
            "WebSocketURI": ".uri",
 | 
					 | 
				
			||||||
        },
 | 
					 | 
				
			||||||
    )
 | 
					 | 
				
			||||||
@@ -1,159 +0,0 @@
 | 
				
			|||||||
from __future__ import annotations
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
import argparse
 | 
					 | 
				
			||||||
import os
 | 
					 | 
				
			||||||
import signal
 | 
					 | 
				
			||||||
import sys
 | 
					 | 
				
			||||||
import threading
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
try:
 | 
					 | 
				
			||||||
    import readline  # noqa: F401
 | 
					 | 
				
			||||||
except ImportError:  # Windows has no `readline` normally
 | 
					 | 
				
			||||||
    pass
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
from .sync.client import ClientConnection, connect
 | 
					 | 
				
			||||||
from .version import version as websockets_version
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
if sys.platform == "win32":
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def win_enable_vt100() -> None:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Enable VT-100 for console output on Windows.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        See also https://github.com/python/cpython/issues/73245.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        import ctypes
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        STD_OUTPUT_HANDLE = ctypes.c_uint(-11)
 | 
					 | 
				
			||||||
        INVALID_HANDLE_VALUE = ctypes.c_uint(-1)
 | 
					 | 
				
			||||||
        ENABLE_VIRTUAL_TERMINAL_PROCESSING = 0x004
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        handle = ctypes.windll.kernel32.GetStdHandle(STD_OUTPUT_HANDLE)
 | 
					 | 
				
			||||||
        if handle == INVALID_HANDLE_VALUE:
 | 
					 | 
				
			||||||
            raise RuntimeError("unable to obtain stdout handle")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        cur_mode = ctypes.c_uint()
 | 
					 | 
				
			||||||
        if ctypes.windll.kernel32.GetConsoleMode(handle, ctypes.byref(cur_mode)) == 0:
 | 
					 | 
				
			||||||
            raise RuntimeError("unable to query current console mode")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # ctypes ints lack support for the required bit-OR operation.
 | 
					 | 
				
			||||||
        # Temporarily convert to Py int, do the OR and convert back.
 | 
					 | 
				
			||||||
        py_int_mode = int.from_bytes(cur_mode, sys.byteorder)
 | 
					 | 
				
			||||||
        new_mode = ctypes.c_uint(py_int_mode | ENABLE_VIRTUAL_TERMINAL_PROCESSING)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if ctypes.windll.kernel32.SetConsoleMode(handle, new_mode) == 0:
 | 
					 | 
				
			||||||
            raise RuntimeError("unable to set console mode")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def print_during_input(string: str) -> None:
 | 
					 | 
				
			||||||
    sys.stdout.write(
 | 
					 | 
				
			||||||
        # Save cursor position
 | 
					 | 
				
			||||||
        "\N{ESC}7"
 | 
					 | 
				
			||||||
        # Add a new line
 | 
					 | 
				
			||||||
        "\N{LINE FEED}"
 | 
					 | 
				
			||||||
        # Move cursor up
 | 
					 | 
				
			||||||
        "\N{ESC}[A"
 | 
					 | 
				
			||||||
        # Insert blank line, scroll last line down
 | 
					 | 
				
			||||||
        "\N{ESC}[L"
 | 
					 | 
				
			||||||
        # Print string in the inserted blank line
 | 
					 | 
				
			||||||
        f"{string}\N{LINE FEED}"
 | 
					 | 
				
			||||||
        # Restore cursor position
 | 
					 | 
				
			||||||
        "\N{ESC}8"
 | 
					 | 
				
			||||||
        # Move cursor down
 | 
					 | 
				
			||||||
        "\N{ESC}[B"
 | 
					 | 
				
			||||||
    )
 | 
					 | 
				
			||||||
    sys.stdout.flush()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def print_over_input(string: str) -> None:
 | 
					 | 
				
			||||||
    sys.stdout.write(
 | 
					 | 
				
			||||||
        # Move cursor to beginning of line
 | 
					 | 
				
			||||||
        "\N{CARRIAGE RETURN}"
 | 
					 | 
				
			||||||
        # Delete current line
 | 
					 | 
				
			||||||
        "\N{ESC}[K"
 | 
					 | 
				
			||||||
        # Print string
 | 
					 | 
				
			||||||
        f"{string}\N{LINE FEED}"
 | 
					 | 
				
			||||||
    )
 | 
					 | 
				
			||||||
    sys.stdout.flush()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def print_incoming_messages(websocket: ClientConnection, stop: threading.Event) -> None:
 | 
					 | 
				
			||||||
    for message in websocket:
 | 
					 | 
				
			||||||
        if isinstance(message, str):
 | 
					 | 
				
			||||||
            print_during_input("< " + message)
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            print_during_input("< (binary) " + message.hex())
 | 
					 | 
				
			||||||
    if not stop.is_set():
 | 
					 | 
				
			||||||
        # When the server closes the connection, raise KeyboardInterrupt
 | 
					 | 
				
			||||||
        # in the main thread to exit the program.
 | 
					 | 
				
			||||||
        if sys.platform == "win32":
 | 
					 | 
				
			||||||
            ctrl_c = signal.CTRL_C_EVENT
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            ctrl_c = signal.SIGINT
 | 
					 | 
				
			||||||
        os.kill(os.getpid(), ctrl_c)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def main() -> None:
 | 
					 | 
				
			||||||
    # Parse command line arguments.
 | 
					 | 
				
			||||||
    parser = argparse.ArgumentParser(
 | 
					 | 
				
			||||||
        prog="python -m websockets",
 | 
					 | 
				
			||||||
        description="Interactive WebSocket client.",
 | 
					 | 
				
			||||||
        add_help=False,
 | 
					 | 
				
			||||||
    )
 | 
					 | 
				
			||||||
    group = parser.add_mutually_exclusive_group()
 | 
					 | 
				
			||||||
    group.add_argument("--version", action="store_true")
 | 
					 | 
				
			||||||
    group.add_argument("uri", metavar="<uri>", nargs="?")
 | 
					 | 
				
			||||||
    args = parser.parse_args()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    if args.version:
 | 
					 | 
				
			||||||
        print(f"websockets {websockets_version}")
 | 
					 | 
				
			||||||
        return
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    if args.uri is None:
 | 
					 | 
				
			||||||
        parser.error("the following arguments are required: <uri>")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    # If we're on Windows, enable VT100 terminal support.
 | 
					 | 
				
			||||||
    if sys.platform == "win32":
 | 
					 | 
				
			||||||
        try:
 | 
					 | 
				
			||||||
            win_enable_vt100()
 | 
					 | 
				
			||||||
        except RuntimeError as exc:
 | 
					 | 
				
			||||||
            sys.stderr.write(
 | 
					 | 
				
			||||||
                f"Unable to set terminal to VT100 mode. This is only "
 | 
					 | 
				
			||||||
                f"supported since Win10 anniversary update. Expect "
 | 
					 | 
				
			||||||
                f"weird symbols on the terminal.\nError: {exc}\n"
 | 
					 | 
				
			||||||
            )
 | 
					 | 
				
			||||||
            sys.stderr.flush()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    try:
 | 
					 | 
				
			||||||
        websocket = connect(args.uri)
 | 
					 | 
				
			||||||
    except Exception as exc:
 | 
					 | 
				
			||||||
        print(f"Failed to connect to {args.uri}: {exc}.")
 | 
					 | 
				
			||||||
        sys.exit(1)
 | 
					 | 
				
			||||||
    else:
 | 
					 | 
				
			||||||
        print(f"Connected to {args.uri}.")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    stop = threading.Event()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    # Start the thread that reads messages from the connection.
 | 
					 | 
				
			||||||
    thread = threading.Thread(target=print_incoming_messages, args=(websocket, stop))
 | 
					 | 
				
			||||||
    thread.start()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    # Read from stdin in the main thread in order to receive signals.
 | 
					 | 
				
			||||||
    try:
 | 
					 | 
				
			||||||
        while True:
 | 
					 | 
				
			||||||
            # Since there's no size limit, put_nowait is identical to put.
 | 
					 | 
				
			||||||
            message = input("> ")
 | 
					 | 
				
			||||||
            websocket.send(message)
 | 
					 | 
				
			||||||
    except (KeyboardInterrupt, EOFError):  # ^C, ^D
 | 
					 | 
				
			||||||
        stop.set()
 | 
					 | 
				
			||||||
        websocket.close()
 | 
					 | 
				
			||||||
        print_over_input("Connection closed.")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    thread.join()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
if __name__ == "__main__":
 | 
					 | 
				
			||||||
    main()
 | 
					 | 
				
			||||||
@@ -1,282 +0,0 @@
 | 
				
			|||||||
# From https://github.com/aio-libs/async-timeout/blob/master/async_timeout/__init__.py
 | 
					 | 
				
			||||||
# Licensed under the Apache License (Apache-2.0)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
import asyncio
 | 
					 | 
				
			||||||
import enum
 | 
					 | 
				
			||||||
import sys
 | 
					 | 
				
			||||||
import warnings
 | 
					 | 
				
			||||||
from types import TracebackType
 | 
					 | 
				
			||||||
from typing import Optional, Type
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
if sys.version_info >= (3, 11):
 | 
					 | 
				
			||||||
    from typing import final
 | 
					 | 
				
			||||||
else:
 | 
					 | 
				
			||||||
    # From https://github.com/python/typing_extensions/blob/main/src/typing_extensions.py
 | 
					 | 
				
			||||||
    # Licensed under the Python Software Foundation License (PSF-2.0)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    # @final exists in 3.8+, but we backport it for all versions
 | 
					 | 
				
			||||||
    # before 3.11 to keep support for the __final__ attribute.
 | 
					 | 
				
			||||||
    # See https://bugs.python.org/issue46342
 | 
					 | 
				
			||||||
    def final(f):
 | 
					 | 
				
			||||||
        """This decorator can be used to indicate to type checkers that
 | 
					 | 
				
			||||||
        the decorated method cannot be overridden, and decorated class
 | 
					 | 
				
			||||||
        cannot be subclassed. For example:
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            class Base:
 | 
					 | 
				
			||||||
                @final
 | 
					 | 
				
			||||||
                def done(self) -> None:
 | 
					 | 
				
			||||||
                    ...
 | 
					 | 
				
			||||||
            class Sub(Base):
 | 
					 | 
				
			||||||
                def done(self) -> None:  # Error reported by type checker
 | 
					 | 
				
			||||||
                    ...
 | 
					 | 
				
			||||||
            @final
 | 
					 | 
				
			||||||
            class Leaf:
 | 
					 | 
				
			||||||
                ...
 | 
					 | 
				
			||||||
            class Other(Leaf):  # Error reported by type checker
 | 
					 | 
				
			||||||
                ...
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        There is no runtime checking of these properties. The decorator
 | 
					 | 
				
			||||||
        sets the ``__final__`` attribute to ``True`` on the decorated object
 | 
					 | 
				
			||||||
        to allow runtime introspection.
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        try:
 | 
					 | 
				
			||||||
            f.__final__ = True
 | 
					 | 
				
			||||||
        except (AttributeError, TypeError):
 | 
					 | 
				
			||||||
            # Skip the attribute silently if it is not writable.
 | 
					 | 
				
			||||||
            # AttributeError happens if the object has __slots__ or a
 | 
					 | 
				
			||||||
            # read-only property, TypeError if it's a builtin class.
 | 
					 | 
				
			||||||
            pass
 | 
					 | 
				
			||||||
        return f
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    # End https://github.com/python/typing_extensions/blob/main/src/typing_extensions.py
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
if sys.version_info >= (3, 11):
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def _uncancel_task(task: "asyncio.Task[object]") -> None:
 | 
					 | 
				
			||||||
        task.uncancel()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
else:
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def _uncancel_task(task: "asyncio.Task[object]") -> None:
 | 
					 | 
				
			||||||
        pass
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
__version__ = "4.0.3"
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
__all__ = ("timeout", "timeout_at", "Timeout")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def timeout(delay: Optional[float]) -> "Timeout":
 | 
					 | 
				
			||||||
    """timeout context manager.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Useful in cases when you want to apply timeout logic around block
 | 
					 | 
				
			||||||
    of code or in cases when asyncio.wait_for is not suitable. For example:
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    >>> async with timeout(0.001):
 | 
					 | 
				
			||||||
    ...     async with aiohttp.get('https://github.com') as r:
 | 
					 | 
				
			||||||
    ...         await r.text()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    delay - value in seconds or None to disable timeout logic
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    loop = asyncio.get_running_loop()
 | 
					 | 
				
			||||||
    if delay is not None:
 | 
					 | 
				
			||||||
        deadline = loop.time() + delay  # type: Optional[float]
 | 
					 | 
				
			||||||
    else:
 | 
					 | 
				
			||||||
        deadline = None
 | 
					 | 
				
			||||||
    return Timeout(deadline, loop)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def timeout_at(deadline: Optional[float]) -> "Timeout":
 | 
					 | 
				
			||||||
    """Schedule the timeout at absolute time.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    deadline argument points on the time in the same clock system
 | 
					 | 
				
			||||||
    as loop.time().
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Please note: it is not POSIX time but a time with
 | 
					 | 
				
			||||||
    undefined starting base, e.g. the time of the system power on.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    >>> async with timeout_at(loop.time() + 10):
 | 
					 | 
				
			||||||
    ...     async with aiohttp.get('https://github.com') as r:
 | 
					 | 
				
			||||||
    ...         await r.text()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    loop = asyncio.get_running_loop()
 | 
					 | 
				
			||||||
    return Timeout(deadline, loop)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class _State(enum.Enum):
 | 
					 | 
				
			||||||
    INIT = "INIT"
 | 
					 | 
				
			||||||
    ENTER = "ENTER"
 | 
					 | 
				
			||||||
    TIMEOUT = "TIMEOUT"
 | 
					 | 
				
			||||||
    EXIT = "EXIT"
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
@final
 | 
					 | 
				
			||||||
class Timeout:
 | 
					 | 
				
			||||||
    # Internal class, please don't instantiate it directly
 | 
					 | 
				
			||||||
    # Use timeout() and timeout_at() public factories instead.
 | 
					 | 
				
			||||||
    #
 | 
					 | 
				
			||||||
    # Implementation note: `async with timeout()` is preferred
 | 
					 | 
				
			||||||
    # over `with timeout()`.
 | 
					 | 
				
			||||||
    # While technically the Timeout class implementation
 | 
					 | 
				
			||||||
    # doesn't need to be async at all,
 | 
					 | 
				
			||||||
    # the `async with` statement explicitly points that
 | 
					 | 
				
			||||||
    # the context manager should be used from async function context.
 | 
					 | 
				
			||||||
    #
 | 
					 | 
				
			||||||
    # This design allows to avoid many silly misusages.
 | 
					 | 
				
			||||||
    #
 | 
					 | 
				
			||||||
    # TimeoutError is raised immediately when scheduled
 | 
					 | 
				
			||||||
    # if the deadline is passed.
 | 
					 | 
				
			||||||
    # The purpose is to time out as soon as possible
 | 
					 | 
				
			||||||
    # without waiting for the next await expression.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    __slots__ = ("_deadline", "_loop", "_state", "_timeout_handler", "_task")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __init__(
 | 
					 | 
				
			||||||
        self, deadline: Optional[float], loop: asyncio.AbstractEventLoop
 | 
					 | 
				
			||||||
    ) -> None:
 | 
					 | 
				
			||||||
        self._loop = loop
 | 
					 | 
				
			||||||
        self._state = _State.INIT
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        self._task: Optional["asyncio.Task[object]"] = None
 | 
					 | 
				
			||||||
        self._timeout_handler = None  # type: Optional[asyncio.Handle]
 | 
					 | 
				
			||||||
        if deadline is None:
 | 
					 | 
				
			||||||
            self._deadline = None  # type: Optional[float]
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            self.update(deadline)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __enter__(self) -> "Timeout":
 | 
					 | 
				
			||||||
        warnings.warn(
 | 
					 | 
				
			||||||
            "with timeout() is deprecated, use async with timeout() instead",
 | 
					 | 
				
			||||||
            DeprecationWarning,
 | 
					 | 
				
			||||||
            stacklevel=2,
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        self._do_enter()
 | 
					 | 
				
			||||||
        return self
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __exit__(
 | 
					 | 
				
			||||||
        self,
 | 
					 | 
				
			||||||
        exc_type: Optional[Type[BaseException]],
 | 
					 | 
				
			||||||
        exc_val: Optional[BaseException],
 | 
					 | 
				
			||||||
        exc_tb: Optional[TracebackType],
 | 
					 | 
				
			||||||
    ) -> Optional[bool]:
 | 
					 | 
				
			||||||
        self._do_exit(exc_type)
 | 
					 | 
				
			||||||
        return None
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    async def __aenter__(self) -> "Timeout":
 | 
					 | 
				
			||||||
        self._do_enter()
 | 
					 | 
				
			||||||
        return self
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    async def __aexit__(
 | 
					 | 
				
			||||||
        self,
 | 
					 | 
				
			||||||
        exc_type: Optional[Type[BaseException]],
 | 
					 | 
				
			||||||
        exc_val: Optional[BaseException],
 | 
					 | 
				
			||||||
        exc_tb: Optional[TracebackType],
 | 
					 | 
				
			||||||
    ) -> Optional[bool]:
 | 
					 | 
				
			||||||
        self._do_exit(exc_type)
 | 
					 | 
				
			||||||
        return None
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    @property
 | 
					 | 
				
			||||||
    def expired(self) -> bool:
 | 
					 | 
				
			||||||
        """Is timeout expired during execution?"""
 | 
					 | 
				
			||||||
        return self._state == _State.TIMEOUT
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    @property
 | 
					 | 
				
			||||||
    def deadline(self) -> Optional[float]:
 | 
					 | 
				
			||||||
        return self._deadline
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def reject(self) -> None:
 | 
					 | 
				
			||||||
        """Reject scheduled timeout if any."""
 | 
					 | 
				
			||||||
        # cancel is maybe better name but
 | 
					 | 
				
			||||||
        # task.cancel() raises CancelledError in asyncio world.
 | 
					 | 
				
			||||||
        if self._state not in (_State.INIT, _State.ENTER):
 | 
					 | 
				
			||||||
            raise RuntimeError(f"invalid state {self._state.value}")
 | 
					 | 
				
			||||||
        self._reject()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def _reject(self) -> None:
 | 
					 | 
				
			||||||
        self._task = None
 | 
					 | 
				
			||||||
        if self._timeout_handler is not None:
 | 
					 | 
				
			||||||
            self._timeout_handler.cancel()
 | 
					 | 
				
			||||||
            self._timeout_handler = None
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def shift(self, delay: float) -> None:
 | 
					 | 
				
			||||||
        """Advance timeout on delay seconds.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        The delay can be negative.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Raise RuntimeError if shift is called when deadline is not scheduled
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        deadline = self._deadline
 | 
					 | 
				
			||||||
        if deadline is None:
 | 
					 | 
				
			||||||
            raise RuntimeError("cannot shift timeout if deadline is not scheduled")
 | 
					 | 
				
			||||||
        self.update(deadline + delay)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def update(self, deadline: float) -> None:
 | 
					 | 
				
			||||||
        """Set deadline to absolute value.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        deadline argument points on the time in the same clock system
 | 
					 | 
				
			||||||
        as loop.time().
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        If new deadline is in the past the timeout is raised immediately.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Please note: it is not POSIX time but a time with
 | 
					 | 
				
			||||||
        undefined starting base, e.g. the time of the system power on.
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        if self._state == _State.EXIT:
 | 
					 | 
				
			||||||
            raise RuntimeError("cannot reschedule after exit from context manager")
 | 
					 | 
				
			||||||
        if self._state == _State.TIMEOUT:
 | 
					 | 
				
			||||||
            raise RuntimeError("cannot reschedule expired timeout")
 | 
					 | 
				
			||||||
        if self._timeout_handler is not None:
 | 
					 | 
				
			||||||
            self._timeout_handler.cancel()
 | 
					 | 
				
			||||||
        self._deadline = deadline
 | 
					 | 
				
			||||||
        if self._state != _State.INIT:
 | 
					 | 
				
			||||||
            self._reschedule()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def _reschedule(self) -> None:
 | 
					 | 
				
			||||||
        assert self._state == _State.ENTER
 | 
					 | 
				
			||||||
        deadline = self._deadline
 | 
					 | 
				
			||||||
        if deadline is None:
 | 
					 | 
				
			||||||
            return
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        now = self._loop.time()
 | 
					 | 
				
			||||||
        if self._timeout_handler is not None:
 | 
					 | 
				
			||||||
            self._timeout_handler.cancel()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        self._task = asyncio.current_task()
 | 
					 | 
				
			||||||
        if deadline <= now:
 | 
					 | 
				
			||||||
            self._timeout_handler = self._loop.call_soon(self._on_timeout)
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            self._timeout_handler = self._loop.call_at(deadline, self._on_timeout)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def _do_enter(self) -> None:
 | 
					 | 
				
			||||||
        if self._state != _State.INIT:
 | 
					 | 
				
			||||||
            raise RuntimeError(f"invalid state {self._state.value}")
 | 
					 | 
				
			||||||
        self._state = _State.ENTER
 | 
					 | 
				
			||||||
        self._reschedule()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def _do_exit(self, exc_type: Optional[Type[BaseException]]) -> None:
 | 
					 | 
				
			||||||
        if exc_type is asyncio.CancelledError and self._state == _State.TIMEOUT:
 | 
					 | 
				
			||||||
            assert self._task is not None
 | 
					 | 
				
			||||||
            _uncancel_task(self._task)
 | 
					 | 
				
			||||||
            self._timeout_handler = None
 | 
					 | 
				
			||||||
            self._task = None
 | 
					 | 
				
			||||||
            raise asyncio.TimeoutError
 | 
					 | 
				
			||||||
        # timeout has not expired
 | 
					 | 
				
			||||||
        self._state = _State.EXIT
 | 
					 | 
				
			||||||
        self._reject()
 | 
					 | 
				
			||||||
        return None
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def _on_timeout(self) -> None:
 | 
					 | 
				
			||||||
        assert self._task is not None
 | 
					 | 
				
			||||||
        self._task.cancel()
 | 
					 | 
				
			||||||
        self._state = _State.TIMEOUT
 | 
					 | 
				
			||||||
        # drop the reference early
 | 
					 | 
				
			||||||
        self._timeout_handler = None
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
# End https://github.com/aio-libs/async-timeout/blob/master/async_timeout/__init__.py
 | 
					 | 
				
			||||||
@@ -1,561 +0,0 @@
 | 
				
			|||||||
from __future__ import annotations
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
import asyncio
 | 
					 | 
				
			||||||
import logging
 | 
					 | 
				
			||||||
import os
 | 
					 | 
				
			||||||
import urllib.parse
 | 
					 | 
				
			||||||
from types import TracebackType
 | 
					 | 
				
			||||||
from typing import Any, AsyncIterator, Callable, Generator, Sequence
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
from ..client import ClientProtocol, backoff
 | 
					 | 
				
			||||||
from ..datastructures import HeadersLike
 | 
					 | 
				
			||||||
from ..exceptions import InvalidStatus, SecurityError
 | 
					 | 
				
			||||||
from ..extensions.base import ClientExtensionFactory
 | 
					 | 
				
			||||||
from ..extensions.permessage_deflate import enable_client_permessage_deflate
 | 
					 | 
				
			||||||
from ..headers import validate_subprotocols
 | 
					 | 
				
			||||||
from ..http11 import USER_AGENT, Response
 | 
					 | 
				
			||||||
from ..protocol import CONNECTING, Event
 | 
					 | 
				
			||||||
from ..typing import LoggerLike, Origin, Subprotocol
 | 
					 | 
				
			||||||
from ..uri import WebSocketURI, parse_uri
 | 
					 | 
				
			||||||
from .compatibility import TimeoutError, asyncio_timeout
 | 
					 | 
				
			||||||
from .connection import Connection
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
__all__ = ["connect", "unix_connect", "ClientConnection"]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
MAX_REDIRECTS = int(os.environ.get("WEBSOCKETS_MAX_REDIRECTS", "10"))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class ClientConnection(Connection):
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    :mod:`asyncio` implementation of a WebSocket client connection.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    :class:`ClientConnection` provides :meth:`recv` and :meth:`send` coroutines
 | 
					 | 
				
			||||||
    for receiving and sending messages.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    It supports asynchronous iteration to receive messages::
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        async for message in websocket:
 | 
					 | 
				
			||||||
            await process(message)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    The iterator exits normally when the connection is closed with close code
 | 
					 | 
				
			||||||
    1000 (OK) or 1001 (going away) or without a close code. It raises a
 | 
					 | 
				
			||||||
    :exc:`~websockets.exceptions.ConnectionClosedError` when the connection is
 | 
					 | 
				
			||||||
    closed with any other code.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    The ``ping_interval``, ``ping_timeout``, ``close_timeout``, ``max_queue``,
 | 
					 | 
				
			||||||
    and ``write_limit`` arguments the same meaning as in :func:`connect`.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Args:
 | 
					 | 
				
			||||||
        protocol: Sans-I/O connection.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __init__(
 | 
					 | 
				
			||||||
        self,
 | 
					 | 
				
			||||||
        protocol: ClientProtocol,
 | 
					 | 
				
			||||||
        *,
 | 
					 | 
				
			||||||
        ping_interval: float | None = 20,
 | 
					 | 
				
			||||||
        ping_timeout: float | None = 20,
 | 
					 | 
				
			||||||
        close_timeout: float | None = 10,
 | 
					 | 
				
			||||||
        max_queue: int | tuple[int, int | None] = 16,
 | 
					 | 
				
			||||||
        write_limit: int | tuple[int, int | None] = 2**15,
 | 
					 | 
				
			||||||
    ) -> None:
 | 
					 | 
				
			||||||
        self.protocol: ClientProtocol
 | 
					 | 
				
			||||||
        super().__init__(
 | 
					 | 
				
			||||||
            protocol,
 | 
					 | 
				
			||||||
            ping_interval=ping_interval,
 | 
					 | 
				
			||||||
            ping_timeout=ping_timeout,
 | 
					 | 
				
			||||||
            close_timeout=close_timeout,
 | 
					 | 
				
			||||||
            max_queue=max_queue,
 | 
					 | 
				
			||||||
            write_limit=write_limit,
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        self.response_rcvd: asyncio.Future[None] = self.loop.create_future()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    async def handshake(
 | 
					 | 
				
			||||||
        self,
 | 
					 | 
				
			||||||
        additional_headers: HeadersLike | None = None,
 | 
					 | 
				
			||||||
        user_agent_header: str | None = USER_AGENT,
 | 
					 | 
				
			||||||
    ) -> None:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Perform the opening handshake.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        async with self.send_context(expected_state=CONNECTING):
 | 
					 | 
				
			||||||
            self.request = self.protocol.connect()
 | 
					 | 
				
			||||||
            if additional_headers is not None:
 | 
					 | 
				
			||||||
                self.request.headers.update(additional_headers)
 | 
					 | 
				
			||||||
            if user_agent_header:
 | 
					 | 
				
			||||||
                self.request.headers["User-Agent"] = user_agent_header
 | 
					 | 
				
			||||||
            self.protocol.send_request(self.request)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        await asyncio.wait(
 | 
					 | 
				
			||||||
            [self.response_rcvd, self.connection_lost_waiter],
 | 
					 | 
				
			||||||
            return_when=asyncio.FIRST_COMPLETED,
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # self.protocol.handshake_exc is always set when the connection is lost
 | 
					 | 
				
			||||||
        # before receiving a response, when the response cannot be parsed, or
 | 
					 | 
				
			||||||
        # when the response fails the handshake.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if self.protocol.handshake_exc is not None:
 | 
					 | 
				
			||||||
            raise self.protocol.handshake_exc
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def process_event(self, event: Event) -> None:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Process one incoming event.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        # First event - handshake response.
 | 
					 | 
				
			||||||
        if self.response is None:
 | 
					 | 
				
			||||||
            assert isinstance(event, Response)
 | 
					 | 
				
			||||||
            self.response = event
 | 
					 | 
				
			||||||
            self.response_rcvd.set_result(None)
 | 
					 | 
				
			||||||
        # Later events - frames.
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            super().process_event(event)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def process_exception(exc: Exception) -> Exception | None:
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Determine whether a connection error is retryable or fatal.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    When reconnecting automatically with ``async for ... in connect(...)``, if a
 | 
					 | 
				
			||||||
    connection attempt fails, :func:`process_exception` is called to determine
 | 
					 | 
				
			||||||
    whether to retry connecting or to raise the exception.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    This function defines the default behavior, which is to retry on:
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    * :exc:`EOFError`, :exc:`OSError`, :exc:`asyncio.TimeoutError`: network
 | 
					 | 
				
			||||||
      errors;
 | 
					 | 
				
			||||||
    * :exc:`~websockets.exceptions.InvalidStatus` when the status code is 500,
 | 
					 | 
				
			||||||
      502, 503, or 504: server or proxy errors.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    All other exceptions are considered fatal.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    You can change this behavior with the ``process_exception`` argument of
 | 
					 | 
				
			||||||
    :func:`connect`.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Return :obj:`None` if the exception is retryable i.e. when the error could
 | 
					 | 
				
			||||||
    be transient and trying to reconnect with the same parameters could succeed.
 | 
					 | 
				
			||||||
    The exception will be logged at the ``INFO`` level.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Return an exception, either ``exc`` or a new exception, if the exception is
 | 
					 | 
				
			||||||
    fatal i.e. when trying to reconnect will most likely produce the same error.
 | 
					 | 
				
			||||||
    That exception will be raised, breaking out of the retry loop.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    if isinstance(exc, (EOFError, OSError, asyncio.TimeoutError)):
 | 
					 | 
				
			||||||
        return None
 | 
					 | 
				
			||||||
    if isinstance(exc, InvalidStatus) and exc.response.status_code in [
 | 
					 | 
				
			||||||
        500,  # Internal Server Error
 | 
					 | 
				
			||||||
        502,  # Bad Gateway
 | 
					 | 
				
			||||||
        503,  # Service Unavailable
 | 
					 | 
				
			||||||
        504,  # Gateway Timeout
 | 
					 | 
				
			||||||
    ]:
 | 
					 | 
				
			||||||
        return None
 | 
					 | 
				
			||||||
    return exc
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
# This is spelled in lower case because it's exposed as a callable in the API.
 | 
					 | 
				
			||||||
class connect:
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Connect to the WebSocket server at ``uri``.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    This coroutine returns a :class:`ClientConnection` instance, which you can
 | 
					 | 
				
			||||||
    use to send and receive messages.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    :func:`connect` may be used as an asynchronous context manager::
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        from websockets.asyncio.client import connect
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        async with connect(...) as websocket:
 | 
					 | 
				
			||||||
            ...
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    The connection is closed automatically when exiting the context.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    :func:`connect` can be used as an infinite asynchronous iterator to
 | 
					 | 
				
			||||||
    reconnect automatically on errors::
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        async for websocket in connect(...):
 | 
					 | 
				
			||||||
            try:
 | 
					 | 
				
			||||||
                ...
 | 
					 | 
				
			||||||
            except websockets.ConnectionClosed:
 | 
					 | 
				
			||||||
                continue
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    If the connection fails with a transient error, it is retried with
 | 
					 | 
				
			||||||
    exponential backoff. If it fails with a fatal error, the exception is
 | 
					 | 
				
			||||||
    raised, breaking out of the loop.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    The connection is closed automatically after each iteration of the loop.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Args:
 | 
					 | 
				
			||||||
        uri: URI of the WebSocket server.
 | 
					 | 
				
			||||||
        origin: Value of the ``Origin`` header, for servers that require it.
 | 
					 | 
				
			||||||
        extensions: List of supported extensions, in order in which they
 | 
					 | 
				
			||||||
            should be negotiated and run.
 | 
					 | 
				
			||||||
        subprotocols: List of supported subprotocols, in order of decreasing
 | 
					 | 
				
			||||||
            preference.
 | 
					 | 
				
			||||||
        additional_headers (HeadersLike | None): Arbitrary HTTP headers to add
 | 
					 | 
				
			||||||
            to the handshake request.
 | 
					 | 
				
			||||||
        user_agent_header: Value of  the ``User-Agent`` request header.
 | 
					 | 
				
			||||||
            It defaults to ``"Python/x.y.z websockets/X.Y"``.
 | 
					 | 
				
			||||||
            Setting it to :obj:`None` removes the header.
 | 
					 | 
				
			||||||
        compression: The "permessage-deflate" extension is enabled by default.
 | 
					 | 
				
			||||||
            Set ``compression`` to :obj:`None` to disable it. See the
 | 
					 | 
				
			||||||
            :doc:`compression guide <../../topics/compression>` for details.
 | 
					 | 
				
			||||||
        process_exception: When reconnecting automatically, tell whether an
 | 
					 | 
				
			||||||
            error is transient or fatal. The default behavior is defined by
 | 
					 | 
				
			||||||
            :func:`process_exception`. Refer to its documentation for details.
 | 
					 | 
				
			||||||
        open_timeout: Timeout for opening the connection in seconds.
 | 
					 | 
				
			||||||
            :obj:`None` disables the timeout.
 | 
					 | 
				
			||||||
        ping_interval: Interval between keepalive pings in seconds.
 | 
					 | 
				
			||||||
            :obj:`None` disables keepalive.
 | 
					 | 
				
			||||||
        ping_timeout: Timeout for keepalive pings in seconds.
 | 
					 | 
				
			||||||
            :obj:`None` disables timeouts.
 | 
					 | 
				
			||||||
        close_timeout: Timeout for closing the connection in seconds.
 | 
					 | 
				
			||||||
            :obj:`None` disables the timeout.
 | 
					 | 
				
			||||||
        max_size: Maximum size of incoming messages in bytes.
 | 
					 | 
				
			||||||
            :obj:`None` disables the limit.
 | 
					 | 
				
			||||||
        max_queue: High-water mark of the buffer where frames are received.
 | 
					 | 
				
			||||||
            It defaults to 16 frames. The low-water mark defaults to ``max_queue
 | 
					 | 
				
			||||||
            // 4``. You may pass a ``(high, low)`` tuple to set the high-water
 | 
					 | 
				
			||||||
            and low-water marks.
 | 
					 | 
				
			||||||
        write_limit: High-water mark of write buffer in bytes. It is passed to
 | 
					 | 
				
			||||||
            :meth:`~asyncio.WriteTransport.set_write_buffer_limits`. It defaults
 | 
					 | 
				
			||||||
            to 32 KiB. You may pass a ``(high, low)`` tuple to set the
 | 
					 | 
				
			||||||
            high-water and low-water marks.
 | 
					 | 
				
			||||||
        logger: Logger for this client.
 | 
					 | 
				
			||||||
            It defaults to ``logging.getLogger("websockets.client")``.
 | 
					 | 
				
			||||||
            See the :doc:`logging guide <../../topics/logging>` for details.
 | 
					 | 
				
			||||||
        create_connection: Factory for the :class:`ClientConnection` managing
 | 
					 | 
				
			||||||
            the connection. Set it to a wrapper or a subclass to customize
 | 
					 | 
				
			||||||
            connection handling.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Any other keyword arguments are passed to the event loop's
 | 
					 | 
				
			||||||
    :meth:`~asyncio.loop.create_connection` method.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    For example:
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    * You can set ``ssl`` to a :class:`~ssl.SSLContext` to enforce TLS settings.
 | 
					 | 
				
			||||||
      When connecting to a ``wss://`` URI, if ``ssl`` isn't provided, a TLS
 | 
					 | 
				
			||||||
      context is created with :func:`~ssl.create_default_context`.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    * You can set ``server_hostname`` to override the host name from ``uri`` in
 | 
					 | 
				
			||||||
      the TLS handshake.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    * You can set ``host`` and ``port`` to connect to a different host and port
 | 
					 | 
				
			||||||
      from those found in ``uri``. This only changes the destination of the TCP
 | 
					 | 
				
			||||||
      connection. The host name from ``uri`` is still used in the TLS handshake
 | 
					 | 
				
			||||||
      for secure connections and in the ``Host`` header.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    * You can set ``sock`` to provide a preexisting TCP socket. You may call
 | 
					 | 
				
			||||||
      :func:`socket.create_connection` (not to be confused with the event loop's
 | 
					 | 
				
			||||||
      :meth:`~asyncio.loop.create_connection` method) to create a suitable
 | 
					 | 
				
			||||||
      client socket and customize it.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Raises:
 | 
					 | 
				
			||||||
        InvalidURI: If ``uri`` isn't a valid WebSocket URI.
 | 
					 | 
				
			||||||
        OSError: If the TCP connection fails.
 | 
					 | 
				
			||||||
        InvalidHandshake: If the opening handshake fails.
 | 
					 | 
				
			||||||
        TimeoutError: If the opening handshake times out.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __init__(
 | 
					 | 
				
			||||||
        self,
 | 
					 | 
				
			||||||
        uri: str,
 | 
					 | 
				
			||||||
        *,
 | 
					 | 
				
			||||||
        # WebSocket
 | 
					 | 
				
			||||||
        origin: Origin | None = None,
 | 
					 | 
				
			||||||
        extensions: Sequence[ClientExtensionFactory] | None = None,
 | 
					 | 
				
			||||||
        subprotocols: Sequence[Subprotocol] | None = None,
 | 
					 | 
				
			||||||
        additional_headers: HeadersLike | None = None,
 | 
					 | 
				
			||||||
        user_agent_header: str | None = USER_AGENT,
 | 
					 | 
				
			||||||
        compression: str | None = "deflate",
 | 
					 | 
				
			||||||
        process_exception: Callable[[Exception], Exception | None] = process_exception,
 | 
					 | 
				
			||||||
        # Timeouts
 | 
					 | 
				
			||||||
        open_timeout: float | None = 10,
 | 
					 | 
				
			||||||
        ping_interval: float | None = 20,
 | 
					 | 
				
			||||||
        ping_timeout: float | None = 20,
 | 
					 | 
				
			||||||
        close_timeout: float | None = 10,
 | 
					 | 
				
			||||||
        # Limits
 | 
					 | 
				
			||||||
        max_size: int | None = 2**20,
 | 
					 | 
				
			||||||
        max_queue: int | tuple[int, int | None] = 16,
 | 
					 | 
				
			||||||
        write_limit: int | tuple[int, int | None] = 2**15,
 | 
					 | 
				
			||||||
        # Logging
 | 
					 | 
				
			||||||
        logger: LoggerLike | None = None,
 | 
					 | 
				
			||||||
        # Escape hatch for advanced customization
 | 
					 | 
				
			||||||
        create_connection: type[ClientConnection] | None = None,
 | 
					 | 
				
			||||||
        # Other keyword arguments are passed to loop.create_connection
 | 
					 | 
				
			||||||
        **kwargs: Any,
 | 
					 | 
				
			||||||
    ) -> None:
 | 
					 | 
				
			||||||
        self.uri = uri
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if subprotocols is not None:
 | 
					 | 
				
			||||||
            validate_subprotocols(subprotocols)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if compression == "deflate":
 | 
					 | 
				
			||||||
            extensions = enable_client_permessage_deflate(extensions)
 | 
					 | 
				
			||||||
        elif compression is not None:
 | 
					 | 
				
			||||||
            raise ValueError(f"unsupported compression: {compression}")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if logger is None:
 | 
					 | 
				
			||||||
            logger = logging.getLogger("websockets.client")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if create_connection is None:
 | 
					 | 
				
			||||||
            create_connection = ClientConnection
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        def protocol_factory(wsuri: WebSocketURI) -> ClientConnection:
 | 
					 | 
				
			||||||
            # This is a protocol in the Sans-I/O implementation of websockets.
 | 
					 | 
				
			||||||
            protocol = ClientProtocol(
 | 
					 | 
				
			||||||
                wsuri,
 | 
					 | 
				
			||||||
                origin=origin,
 | 
					 | 
				
			||||||
                extensions=extensions,
 | 
					 | 
				
			||||||
                subprotocols=subprotocols,
 | 
					 | 
				
			||||||
                max_size=max_size,
 | 
					 | 
				
			||||||
                logger=logger,
 | 
					 | 
				
			||||||
            )
 | 
					 | 
				
			||||||
            # This is a connection in websockets and a protocol in asyncio.
 | 
					 | 
				
			||||||
            connection = create_connection(
 | 
					 | 
				
			||||||
                protocol,
 | 
					 | 
				
			||||||
                ping_interval=ping_interval,
 | 
					 | 
				
			||||||
                ping_timeout=ping_timeout,
 | 
					 | 
				
			||||||
                close_timeout=close_timeout,
 | 
					 | 
				
			||||||
                max_queue=max_queue,
 | 
					 | 
				
			||||||
                write_limit=write_limit,
 | 
					 | 
				
			||||||
            )
 | 
					 | 
				
			||||||
            return connection
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        self.protocol_factory = protocol_factory
 | 
					 | 
				
			||||||
        self.handshake_args = (
 | 
					 | 
				
			||||||
            additional_headers,
 | 
					 | 
				
			||||||
            user_agent_header,
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        self.process_exception = process_exception
 | 
					 | 
				
			||||||
        self.open_timeout = open_timeout
 | 
					 | 
				
			||||||
        self.logger = logger
 | 
					 | 
				
			||||||
        self.connection_kwargs = kwargs
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    async def create_connection(self) -> ClientConnection:
 | 
					 | 
				
			||||||
        """Create TCP or Unix connection."""
 | 
					 | 
				
			||||||
        loop = asyncio.get_running_loop()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        wsuri = parse_uri(self.uri)
 | 
					 | 
				
			||||||
        kwargs = self.connection_kwargs.copy()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        def factory() -> ClientConnection:
 | 
					 | 
				
			||||||
            return self.protocol_factory(wsuri)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if wsuri.secure:
 | 
					 | 
				
			||||||
            kwargs.setdefault("ssl", True)
 | 
					 | 
				
			||||||
            kwargs.setdefault("server_hostname", wsuri.host)
 | 
					 | 
				
			||||||
            if kwargs.get("ssl") is None:
 | 
					 | 
				
			||||||
                raise TypeError("ssl=None is incompatible with a wss:// URI")
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            if kwargs.get("ssl") is not None:
 | 
					 | 
				
			||||||
                raise TypeError("ssl argument is incompatible with a ws:// URI")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if kwargs.pop("unix", False):
 | 
					 | 
				
			||||||
            _, connection = await loop.create_unix_connection(factory, **kwargs)
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            if kwargs.get("sock") is None:
 | 
					 | 
				
			||||||
                kwargs.setdefault("host", wsuri.host)
 | 
					 | 
				
			||||||
                kwargs.setdefault("port", wsuri.port)
 | 
					 | 
				
			||||||
            _, connection = await loop.create_connection(factory, **kwargs)
 | 
					 | 
				
			||||||
        return connection
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def process_redirect(self, exc: Exception) -> Exception | str:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Determine whether a connection error is a redirect that can be followed.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Return the new URI if it's a valid redirect. Else, return an exception.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        if not (
 | 
					 | 
				
			||||||
            isinstance(exc, InvalidStatus)
 | 
					 | 
				
			||||||
            and exc.response.status_code
 | 
					 | 
				
			||||||
            in [
 | 
					 | 
				
			||||||
                300,  # Multiple Choices
 | 
					 | 
				
			||||||
                301,  # Moved Permanently
 | 
					 | 
				
			||||||
                302,  # Found
 | 
					 | 
				
			||||||
                303,  # See Other
 | 
					 | 
				
			||||||
                307,  # Temporary Redirect
 | 
					 | 
				
			||||||
                308,  # Permanent Redirect
 | 
					 | 
				
			||||||
            ]
 | 
					 | 
				
			||||||
            and "Location" in exc.response.headers
 | 
					 | 
				
			||||||
        ):
 | 
					 | 
				
			||||||
            return exc
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        old_wsuri = parse_uri(self.uri)
 | 
					 | 
				
			||||||
        new_uri = urllib.parse.urljoin(self.uri, exc.response.headers["Location"])
 | 
					 | 
				
			||||||
        new_wsuri = parse_uri(new_uri)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # If connect() received a socket, it is closed and cannot be reused.
 | 
					 | 
				
			||||||
        if self.connection_kwargs.get("sock") is not None:
 | 
					 | 
				
			||||||
            return ValueError(
 | 
					 | 
				
			||||||
                f"cannot follow redirect to {new_uri} with a preexisting socket"
 | 
					 | 
				
			||||||
            )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # TLS downgrade is forbidden.
 | 
					 | 
				
			||||||
        if old_wsuri.secure and not new_wsuri.secure:
 | 
					 | 
				
			||||||
            return SecurityError(f"cannot follow redirect to non-secure URI {new_uri}")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # Apply restrictions to cross-origin redirects.
 | 
					 | 
				
			||||||
        if (
 | 
					 | 
				
			||||||
            old_wsuri.secure != new_wsuri.secure
 | 
					 | 
				
			||||||
            or old_wsuri.host != new_wsuri.host
 | 
					 | 
				
			||||||
            or old_wsuri.port != new_wsuri.port
 | 
					 | 
				
			||||||
        ):
 | 
					 | 
				
			||||||
            # Cross-origin redirects on Unix sockets don't quite make sense.
 | 
					 | 
				
			||||||
            if self.connection_kwargs.get("unix", False):
 | 
					 | 
				
			||||||
                return ValueError(
 | 
					 | 
				
			||||||
                    f"cannot follow cross-origin redirect to {new_uri} "
 | 
					 | 
				
			||||||
                    f"with a Unix socket"
 | 
					 | 
				
			||||||
                )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            # Cross-origin redirects when host and port are overridden are ill-defined.
 | 
					 | 
				
			||||||
            if (
 | 
					 | 
				
			||||||
                self.connection_kwargs.get("host") is not None
 | 
					 | 
				
			||||||
                or self.connection_kwargs.get("port") is not None
 | 
					 | 
				
			||||||
            ):
 | 
					 | 
				
			||||||
                return ValueError(
 | 
					 | 
				
			||||||
                    f"cannot follow cross-origin redirect to {new_uri} "
 | 
					 | 
				
			||||||
                    f"with an explicit host or port"
 | 
					 | 
				
			||||||
                )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        return new_uri
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    # ... = await connect(...)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __await__(self) -> Generator[Any, None, ClientConnection]:
 | 
					 | 
				
			||||||
        # Create a suitable iterator by calling __await__ on a coroutine.
 | 
					 | 
				
			||||||
        return self.__await_impl__().__await__()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    async def __await_impl__(self) -> ClientConnection:
 | 
					 | 
				
			||||||
        try:
 | 
					 | 
				
			||||||
            async with asyncio_timeout(self.open_timeout):
 | 
					 | 
				
			||||||
                for _ in range(MAX_REDIRECTS):
 | 
					 | 
				
			||||||
                    self.connection = await self.create_connection()
 | 
					 | 
				
			||||||
                    try:
 | 
					 | 
				
			||||||
                        await self.connection.handshake(*self.handshake_args)
 | 
					 | 
				
			||||||
                    except asyncio.CancelledError:
 | 
					 | 
				
			||||||
                        self.connection.close_transport()
 | 
					 | 
				
			||||||
                        raise
 | 
					 | 
				
			||||||
                    except Exception as exc:
 | 
					 | 
				
			||||||
                        # Always close the connection even though keep-alive is
 | 
					 | 
				
			||||||
                        # the default in HTTP/1.1 because create_connection ties
 | 
					 | 
				
			||||||
                        # opening the network connection with initializing the
 | 
					 | 
				
			||||||
                        # protocol. In the current design of connect(), there is
 | 
					 | 
				
			||||||
                        # no easy way to reuse the network connection that works
 | 
					 | 
				
			||||||
                        # in every case nor to reinitialize the protocol.
 | 
					 | 
				
			||||||
                        self.connection.close_transport()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                        uri_or_exc = self.process_redirect(exc)
 | 
					 | 
				
			||||||
                        # Response is a valid redirect; follow it.
 | 
					 | 
				
			||||||
                        if isinstance(uri_or_exc, str):
 | 
					 | 
				
			||||||
                            self.uri = uri_or_exc
 | 
					 | 
				
			||||||
                            continue
 | 
					 | 
				
			||||||
                        # Response isn't a valid redirect; raise the exception.
 | 
					 | 
				
			||||||
                        if uri_or_exc is exc:
 | 
					 | 
				
			||||||
                            raise
 | 
					 | 
				
			||||||
                        else:
 | 
					 | 
				
			||||||
                            raise uri_or_exc from exc
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                    else:
 | 
					 | 
				
			||||||
                        self.connection.start_keepalive()
 | 
					 | 
				
			||||||
                        return self.connection
 | 
					 | 
				
			||||||
                else:
 | 
					 | 
				
			||||||
                    raise SecurityError(f"more than {MAX_REDIRECTS} redirects")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        except TimeoutError:
 | 
					 | 
				
			||||||
            # Re-raise exception with an informative error message.
 | 
					 | 
				
			||||||
            raise TimeoutError("timed out during handshake") from None
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    # ... = yield from connect(...) - remove when dropping Python < 3.10
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    __iter__ = __await__
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    # async with connect(...) as ...: ...
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    async def __aenter__(self) -> ClientConnection:
 | 
					 | 
				
			||||||
        return await self
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    async def __aexit__(
 | 
					 | 
				
			||||||
        self,
 | 
					 | 
				
			||||||
        exc_type: type[BaseException] | None,
 | 
					 | 
				
			||||||
        exc_value: BaseException | None,
 | 
					 | 
				
			||||||
        traceback: TracebackType | None,
 | 
					 | 
				
			||||||
    ) -> None:
 | 
					 | 
				
			||||||
        await self.connection.close()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    # async for ... in connect(...):
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    async def __aiter__(self) -> AsyncIterator[ClientConnection]:
 | 
					 | 
				
			||||||
        delays: Generator[float, None, None] | None = None
 | 
					 | 
				
			||||||
        while True:
 | 
					 | 
				
			||||||
            try:
 | 
					 | 
				
			||||||
                async with self as protocol:
 | 
					 | 
				
			||||||
                    yield protocol
 | 
					 | 
				
			||||||
            except Exception as exc:
 | 
					 | 
				
			||||||
                # Determine whether the exception is retryable or fatal.
 | 
					 | 
				
			||||||
                # The API of process_exception is "return an exception or None";
 | 
					 | 
				
			||||||
                # "raise an exception" is also supported because it's a frequent
 | 
					 | 
				
			||||||
                # mistake. It isn't documented in order to keep the API simple.
 | 
					 | 
				
			||||||
                try:
 | 
					 | 
				
			||||||
                    new_exc = self.process_exception(exc)
 | 
					 | 
				
			||||||
                except Exception as raised_exc:
 | 
					 | 
				
			||||||
                    new_exc = raised_exc
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                # The connection failed with a fatal error.
 | 
					 | 
				
			||||||
                # Raise the exception and exit the loop.
 | 
					 | 
				
			||||||
                if new_exc is exc:
 | 
					 | 
				
			||||||
                    raise
 | 
					 | 
				
			||||||
                if new_exc is not None:
 | 
					 | 
				
			||||||
                    raise new_exc from exc
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                # The connection failed with a retryable error.
 | 
					 | 
				
			||||||
                # Start or continue backoff and reconnect.
 | 
					 | 
				
			||||||
                if delays is None:
 | 
					 | 
				
			||||||
                    delays = backoff()
 | 
					 | 
				
			||||||
                delay = next(delays)
 | 
					 | 
				
			||||||
                self.logger.info(
 | 
					 | 
				
			||||||
                    "! connect failed; reconnecting in %.1f seconds",
 | 
					 | 
				
			||||||
                    delay,
 | 
					 | 
				
			||||||
                    exc_info=True,
 | 
					 | 
				
			||||||
                )
 | 
					 | 
				
			||||||
                await asyncio.sleep(delay)
 | 
					 | 
				
			||||||
                continue
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            else:
 | 
					 | 
				
			||||||
                # The connection succeeded. Reset backoff.
 | 
					 | 
				
			||||||
                delays = None
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def unix_connect(
 | 
					 | 
				
			||||||
    path: str | None = None,
 | 
					 | 
				
			||||||
    uri: str | None = None,
 | 
					 | 
				
			||||||
    **kwargs: Any,
 | 
					 | 
				
			||||||
) -> connect:
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Connect to a WebSocket server listening on a Unix socket.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    This function accepts the same keyword arguments as :func:`connect`.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    It's only available on Unix.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    It's mainly useful for debugging servers listening on Unix sockets.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Args:
 | 
					 | 
				
			||||||
        path: File system path to the Unix socket.
 | 
					 | 
				
			||||||
        uri: URI of the WebSocket server. ``uri`` defaults to
 | 
					 | 
				
			||||||
            ``ws://localhost/`` or, when a ``ssl`` argument is provided, to
 | 
					 | 
				
			||||||
            ``wss://localhost/``.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    if uri is None:
 | 
					 | 
				
			||||||
        if kwargs.get("ssl") is None:
 | 
					 | 
				
			||||||
            uri = "ws://localhost/"
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            uri = "wss://localhost/"
 | 
					 | 
				
			||||||
    return connect(uri=uri, unix=True, path=path, **kwargs)
 | 
					 | 
				
			||||||
@@ -1,30 +0,0 @@
 | 
				
			|||||||
from __future__ import annotations
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
import sys
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
__all__ = ["TimeoutError", "aiter", "anext", "asyncio_timeout", "asyncio_timeout_at"]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
if sys.version_info[:2] >= (3, 11):
 | 
					 | 
				
			||||||
    TimeoutError = TimeoutError
 | 
					 | 
				
			||||||
    aiter = aiter
 | 
					 | 
				
			||||||
    anext = anext
 | 
					 | 
				
			||||||
    from asyncio import (
 | 
					 | 
				
			||||||
        timeout as asyncio_timeout,  # noqa: F401
 | 
					 | 
				
			||||||
        timeout_at as asyncio_timeout_at,  # noqa: F401
 | 
					 | 
				
			||||||
    )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
else:  # Python < 3.11
 | 
					 | 
				
			||||||
    from asyncio import TimeoutError
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def aiter(async_iterable):
 | 
					 | 
				
			||||||
        return type(async_iterable).__aiter__(async_iterable)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    async def anext(async_iterator):
 | 
					 | 
				
			||||||
        return await type(async_iterator).__anext__(async_iterator)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    from .async_timeout import (
 | 
					 | 
				
			||||||
        timeout as asyncio_timeout,  # noqa: F401
 | 
					 | 
				
			||||||
        timeout_at as asyncio_timeout_at,  # noqa: F401
 | 
					 | 
				
			||||||
    )
 | 
					 | 
				
			||||||
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							@@ -1,293 +0,0 @@
 | 
				
			|||||||
from __future__ import annotations
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
import asyncio
 | 
					 | 
				
			||||||
import codecs
 | 
					 | 
				
			||||||
import collections
 | 
					 | 
				
			||||||
from typing import (
 | 
					 | 
				
			||||||
    Any,
 | 
					 | 
				
			||||||
    AsyncIterator,
 | 
					 | 
				
			||||||
    Callable,
 | 
					 | 
				
			||||||
    Generic,
 | 
					 | 
				
			||||||
    Iterable,
 | 
					 | 
				
			||||||
    TypeVar,
 | 
					 | 
				
			||||||
)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
from ..exceptions import ConcurrencyError
 | 
					 | 
				
			||||||
from ..frames import OP_BINARY, OP_CONT, OP_TEXT, Frame
 | 
					 | 
				
			||||||
from ..typing import Data
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
__all__ = ["Assembler"]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
UTF8Decoder = codecs.getincrementaldecoder("utf-8")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
T = TypeVar("T")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class SimpleQueue(Generic[T]):
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Simplified version of :class:`asyncio.Queue`.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Provides only the subset of functionality needed by :class:`Assembler`.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __init__(self) -> None:
 | 
					 | 
				
			||||||
        self.loop = asyncio.get_running_loop()
 | 
					 | 
				
			||||||
        self.get_waiter: asyncio.Future[None] | None = None
 | 
					 | 
				
			||||||
        self.queue: collections.deque[T] = collections.deque()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __len__(self) -> int:
 | 
					 | 
				
			||||||
        return len(self.queue)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def put(self, item: T) -> None:
 | 
					 | 
				
			||||||
        """Put an item into the queue without waiting."""
 | 
					 | 
				
			||||||
        self.queue.append(item)
 | 
					 | 
				
			||||||
        if self.get_waiter is not None and not self.get_waiter.done():
 | 
					 | 
				
			||||||
            self.get_waiter.set_result(None)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    async def get(self) -> T:
 | 
					 | 
				
			||||||
        """Remove and return an item from the queue, waiting if necessary."""
 | 
					 | 
				
			||||||
        if not self.queue:
 | 
					 | 
				
			||||||
            if self.get_waiter is not None:
 | 
					 | 
				
			||||||
                raise ConcurrencyError("get is already running")
 | 
					 | 
				
			||||||
            self.get_waiter = self.loop.create_future()
 | 
					 | 
				
			||||||
            try:
 | 
					 | 
				
			||||||
                await self.get_waiter
 | 
					 | 
				
			||||||
            finally:
 | 
					 | 
				
			||||||
                self.get_waiter.cancel()
 | 
					 | 
				
			||||||
                self.get_waiter = None
 | 
					 | 
				
			||||||
        return self.queue.popleft()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def reset(self, items: Iterable[T]) -> None:
 | 
					 | 
				
			||||||
        """Put back items into an empty, idle queue."""
 | 
					 | 
				
			||||||
        assert self.get_waiter is None, "cannot reset() while get() is running"
 | 
					 | 
				
			||||||
        assert not self.queue, "cannot reset() while queue isn't empty"
 | 
					 | 
				
			||||||
        self.queue.extend(items)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def abort(self) -> None:
 | 
					 | 
				
			||||||
        if self.get_waiter is not None and not self.get_waiter.done():
 | 
					 | 
				
			||||||
            self.get_waiter.set_exception(EOFError("stream of frames ended"))
 | 
					 | 
				
			||||||
        # Clear the queue to avoid storing unnecessary data in memory.
 | 
					 | 
				
			||||||
        self.queue.clear()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class Assembler:
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Assemble messages from frames.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    :class:`Assembler` expects only data frames. The stream of frames must
 | 
					 | 
				
			||||||
    respect the protocol; if it doesn't, the behavior is undefined.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Args:
 | 
					 | 
				
			||||||
        pause: Called when the buffer of frames goes above the high water mark;
 | 
					 | 
				
			||||||
            should pause reading from the network.
 | 
					 | 
				
			||||||
        resume: Called when the buffer of frames goes below the low water mark;
 | 
					 | 
				
			||||||
            should resume reading from the network.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    # coverage reports incorrectly: "line NN didn't jump to the function exit"
 | 
					 | 
				
			||||||
    def __init__(  # pragma: no cover
 | 
					 | 
				
			||||||
        self,
 | 
					 | 
				
			||||||
        high: int = 16,
 | 
					 | 
				
			||||||
        low: int | None = None,
 | 
					 | 
				
			||||||
        pause: Callable[[], Any] = lambda: None,
 | 
					 | 
				
			||||||
        resume: Callable[[], Any] = lambda: None,
 | 
					 | 
				
			||||||
    ) -> None:
 | 
					 | 
				
			||||||
        # Queue of incoming messages. Each item is a queue of frames.
 | 
					 | 
				
			||||||
        self.frames: SimpleQueue[Frame] = SimpleQueue()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # We cannot put a hard limit on the size of the queue because a single
 | 
					 | 
				
			||||||
        # call to Protocol.data_received() could produce thousands of frames,
 | 
					 | 
				
			||||||
        # which must be buffered. Instead, we pause reading when the buffer goes
 | 
					 | 
				
			||||||
        # above the high limit and we resume when it goes under the low limit.
 | 
					 | 
				
			||||||
        if low is None:
 | 
					 | 
				
			||||||
            low = high // 4
 | 
					 | 
				
			||||||
        if low < 0:
 | 
					 | 
				
			||||||
            raise ValueError("low must be positive or equal to zero")
 | 
					 | 
				
			||||||
        if high < low:
 | 
					 | 
				
			||||||
            raise ValueError("high must be greater than or equal to low")
 | 
					 | 
				
			||||||
        self.high, self.low = high, low
 | 
					 | 
				
			||||||
        self.pause = pause
 | 
					 | 
				
			||||||
        self.resume = resume
 | 
					 | 
				
			||||||
        self.paused = False
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # This flag prevents concurrent calls to get() by user code.
 | 
					 | 
				
			||||||
        self.get_in_progress = False
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # This flag marks the end of the connection.
 | 
					 | 
				
			||||||
        self.closed = False
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    async def get(self, decode: bool | None = None) -> Data:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Read the next message.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        :meth:`get` returns a single :class:`str` or :class:`bytes`.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        If the message is fragmented, :meth:`get` waits until the last frame is
 | 
					 | 
				
			||||||
        received, then it reassembles the message and returns it. To receive
 | 
					 | 
				
			||||||
        messages frame by frame, use :meth:`get_iter` instead.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Args:
 | 
					 | 
				
			||||||
            decode: :obj:`False` disables UTF-8 decoding of text frames and
 | 
					 | 
				
			||||||
                returns :class:`bytes`. :obj:`True` forces UTF-8 decoding of
 | 
					 | 
				
			||||||
                binary frames and returns :class:`str`.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Raises:
 | 
					 | 
				
			||||||
            EOFError: If the stream of frames has ended.
 | 
					 | 
				
			||||||
            ConcurrencyError: If two coroutines run :meth:`get` or
 | 
					 | 
				
			||||||
                :meth:`get_iter` concurrently.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        if self.closed:
 | 
					 | 
				
			||||||
            raise EOFError("stream of frames ended")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if self.get_in_progress:
 | 
					 | 
				
			||||||
            raise ConcurrencyError("get() or get_iter() is already running")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # Locking with get_in_progress ensures only one coroutine can get here.
 | 
					 | 
				
			||||||
        self.get_in_progress = True
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # First frame
 | 
					 | 
				
			||||||
        try:
 | 
					 | 
				
			||||||
            frame = await self.frames.get()
 | 
					 | 
				
			||||||
        except asyncio.CancelledError:
 | 
					 | 
				
			||||||
            self.get_in_progress = False
 | 
					 | 
				
			||||||
            raise
 | 
					 | 
				
			||||||
        self.maybe_resume()
 | 
					 | 
				
			||||||
        assert frame.opcode is OP_TEXT or frame.opcode is OP_BINARY
 | 
					 | 
				
			||||||
        if decode is None:
 | 
					 | 
				
			||||||
            decode = frame.opcode is OP_TEXT
 | 
					 | 
				
			||||||
        frames = [frame]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # Following frames, for fragmented messages
 | 
					 | 
				
			||||||
        while not frame.fin:
 | 
					 | 
				
			||||||
            try:
 | 
					 | 
				
			||||||
                frame = await self.frames.get()
 | 
					 | 
				
			||||||
            except asyncio.CancelledError:
 | 
					 | 
				
			||||||
                # Put frames already received back into the queue
 | 
					 | 
				
			||||||
                # so that future calls to get() can return them.
 | 
					 | 
				
			||||||
                self.frames.reset(frames)
 | 
					 | 
				
			||||||
                self.get_in_progress = False
 | 
					 | 
				
			||||||
                raise
 | 
					 | 
				
			||||||
            self.maybe_resume()
 | 
					 | 
				
			||||||
            assert frame.opcode is OP_CONT
 | 
					 | 
				
			||||||
            frames.append(frame)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        self.get_in_progress = False
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        data = b"".join(frame.data for frame in frames)
 | 
					 | 
				
			||||||
        if decode:
 | 
					 | 
				
			||||||
            return data.decode()
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            return data
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    async def get_iter(self, decode: bool | None = None) -> AsyncIterator[Data]:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Stream the next message.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Iterating the return value of :meth:`get_iter` asynchronously yields a
 | 
					 | 
				
			||||||
        :class:`str` or :class:`bytes` for each frame in the message.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        The iterator must be fully consumed before calling :meth:`get_iter` or
 | 
					 | 
				
			||||||
        :meth:`get` again. Else, :exc:`ConcurrencyError` is raised.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        This method only makes sense for fragmented messages. If messages aren't
 | 
					 | 
				
			||||||
        fragmented, use :meth:`get` instead.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Args:
 | 
					 | 
				
			||||||
            decode: :obj:`False` disables UTF-8 decoding of text frames and
 | 
					 | 
				
			||||||
                returns :class:`bytes`. :obj:`True` forces UTF-8 decoding of
 | 
					 | 
				
			||||||
                binary frames and returns :class:`str`.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Raises:
 | 
					 | 
				
			||||||
            EOFError: If the stream of frames has ended.
 | 
					 | 
				
			||||||
            ConcurrencyError: If two coroutines run :meth:`get` or
 | 
					 | 
				
			||||||
                :meth:`get_iter` concurrently.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        if self.closed:
 | 
					 | 
				
			||||||
            raise EOFError("stream of frames ended")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if self.get_in_progress:
 | 
					 | 
				
			||||||
            raise ConcurrencyError("get() or get_iter() is already running")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # Locking with get_in_progress ensures only one coroutine can get here.
 | 
					 | 
				
			||||||
        self.get_in_progress = True
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # First frame
 | 
					 | 
				
			||||||
        try:
 | 
					 | 
				
			||||||
            frame = await self.frames.get()
 | 
					 | 
				
			||||||
        except asyncio.CancelledError:
 | 
					 | 
				
			||||||
            self.get_in_progress = False
 | 
					 | 
				
			||||||
            raise
 | 
					 | 
				
			||||||
        self.maybe_resume()
 | 
					 | 
				
			||||||
        assert frame.opcode is OP_TEXT or frame.opcode is OP_BINARY
 | 
					 | 
				
			||||||
        if decode is None:
 | 
					 | 
				
			||||||
            decode = frame.opcode is OP_TEXT
 | 
					 | 
				
			||||||
        if decode:
 | 
					 | 
				
			||||||
            decoder = UTF8Decoder()
 | 
					 | 
				
			||||||
            yield decoder.decode(frame.data, frame.fin)
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            yield frame.data
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # Following frames, for fragmented messages
 | 
					 | 
				
			||||||
        while not frame.fin:
 | 
					 | 
				
			||||||
            # We cannot handle asyncio.CancelledError because we don't buffer
 | 
					 | 
				
			||||||
            # previous fragments — we're streaming them. Canceling get_iter()
 | 
					 | 
				
			||||||
            # here will leave the assembler in a stuck state. Future calls to
 | 
					 | 
				
			||||||
            # get() or get_iter() will raise ConcurrencyError.
 | 
					 | 
				
			||||||
            frame = await self.frames.get()
 | 
					 | 
				
			||||||
            self.maybe_resume()
 | 
					 | 
				
			||||||
            assert frame.opcode is OP_CONT
 | 
					 | 
				
			||||||
            if decode:
 | 
					 | 
				
			||||||
                yield decoder.decode(frame.data, frame.fin)
 | 
					 | 
				
			||||||
            else:
 | 
					 | 
				
			||||||
                yield frame.data
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        self.get_in_progress = False
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def put(self, frame: Frame) -> None:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Add ``frame`` to the next message.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Raises:
 | 
					 | 
				
			||||||
            EOFError: If the stream of frames has ended.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        if self.closed:
 | 
					 | 
				
			||||||
            raise EOFError("stream of frames ended")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        self.frames.put(frame)
 | 
					 | 
				
			||||||
        self.maybe_pause()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def maybe_pause(self) -> None:
 | 
					 | 
				
			||||||
        """Pause the writer if queue is above the high water mark."""
 | 
					 | 
				
			||||||
        # Check for "> high" to support high = 0
 | 
					 | 
				
			||||||
        if len(self.frames) > self.high and not self.paused:
 | 
					 | 
				
			||||||
            self.paused = True
 | 
					 | 
				
			||||||
            self.pause()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def maybe_resume(self) -> None:
 | 
					 | 
				
			||||||
        """Resume the writer if queue is below the low water mark."""
 | 
					 | 
				
			||||||
        # Check for "<= low" to support low = 0
 | 
					 | 
				
			||||||
        if len(self.frames) <= self.low and self.paused:
 | 
					 | 
				
			||||||
            self.paused = False
 | 
					 | 
				
			||||||
            self.resume()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def close(self) -> None:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        End the stream of frames.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Callling :meth:`close` concurrently with :meth:`get`, :meth:`get_iter`,
 | 
					 | 
				
			||||||
        or :meth:`put` is safe. They will raise :exc:`EOFError`.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        if self.closed:
 | 
					 | 
				
			||||||
            return
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        self.closed = True
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # Unblock get() or get_iter().
 | 
					 | 
				
			||||||
        self.frames.abort()
 | 
					 | 
				
			||||||
@@ -1,973 +0,0 @@
 | 
				
			|||||||
from __future__ import annotations
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
import asyncio
 | 
					 | 
				
			||||||
import hmac
 | 
					 | 
				
			||||||
import http
 | 
					 | 
				
			||||||
import logging
 | 
					 | 
				
			||||||
import socket
 | 
					 | 
				
			||||||
import sys
 | 
					 | 
				
			||||||
from types import TracebackType
 | 
					 | 
				
			||||||
from typing import (
 | 
					 | 
				
			||||||
    Any,
 | 
					 | 
				
			||||||
    Awaitable,
 | 
					 | 
				
			||||||
    Callable,
 | 
					 | 
				
			||||||
    Generator,
 | 
					 | 
				
			||||||
    Iterable,
 | 
					 | 
				
			||||||
    Sequence,
 | 
					 | 
				
			||||||
    Tuple,
 | 
					 | 
				
			||||||
    cast,
 | 
					 | 
				
			||||||
)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
from ..exceptions import InvalidHeader
 | 
					 | 
				
			||||||
from ..extensions.base import ServerExtensionFactory
 | 
					 | 
				
			||||||
from ..extensions.permessage_deflate import enable_server_permessage_deflate
 | 
					 | 
				
			||||||
from ..frames import CloseCode
 | 
					 | 
				
			||||||
from ..headers import (
 | 
					 | 
				
			||||||
    build_www_authenticate_basic,
 | 
					 | 
				
			||||||
    parse_authorization_basic,
 | 
					 | 
				
			||||||
    validate_subprotocols,
 | 
					 | 
				
			||||||
)
 | 
					 | 
				
			||||||
from ..http11 import SERVER, Request, Response
 | 
					 | 
				
			||||||
from ..protocol import CONNECTING, OPEN, Event
 | 
					 | 
				
			||||||
from ..server import ServerProtocol
 | 
					 | 
				
			||||||
from ..typing import LoggerLike, Origin, StatusLike, Subprotocol
 | 
					 | 
				
			||||||
from .compatibility import asyncio_timeout
 | 
					 | 
				
			||||||
from .connection import Connection, broadcast
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
__all__ = [
 | 
					 | 
				
			||||||
    "broadcast",
 | 
					 | 
				
			||||||
    "serve",
 | 
					 | 
				
			||||||
    "unix_serve",
 | 
					 | 
				
			||||||
    "ServerConnection",
 | 
					 | 
				
			||||||
    "Server",
 | 
					 | 
				
			||||||
    "basic_auth",
 | 
					 | 
				
			||||||
]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class ServerConnection(Connection):
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    :mod:`asyncio` implementation of a WebSocket server connection.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    :class:`ServerConnection` provides :meth:`recv` and :meth:`send` methods for
 | 
					 | 
				
			||||||
    receiving and sending messages.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    It supports asynchronous iteration to receive messages::
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        async for message in websocket:
 | 
					 | 
				
			||||||
            await process(message)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    The iterator exits normally when the connection is closed with close code
 | 
					 | 
				
			||||||
    1000 (OK) or 1001 (going away) or without a close code. It raises a
 | 
					 | 
				
			||||||
    :exc:`~websockets.exceptions.ConnectionClosedError` when the connection is
 | 
					 | 
				
			||||||
    closed with any other code.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    The ``ping_interval``, ``ping_timeout``, ``close_timeout``, ``max_queue``,
 | 
					 | 
				
			||||||
    and ``write_limit`` arguments the same meaning as in :func:`serve`.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Args:
 | 
					 | 
				
			||||||
        protocol: Sans-I/O connection.
 | 
					 | 
				
			||||||
        server: Server that manages this connection.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __init__(
 | 
					 | 
				
			||||||
        self,
 | 
					 | 
				
			||||||
        protocol: ServerProtocol,
 | 
					 | 
				
			||||||
        server: Server,
 | 
					 | 
				
			||||||
        *,
 | 
					 | 
				
			||||||
        ping_interval: float | None = 20,
 | 
					 | 
				
			||||||
        ping_timeout: float | None = 20,
 | 
					 | 
				
			||||||
        close_timeout: float | None = 10,
 | 
					 | 
				
			||||||
        max_queue: int | tuple[int, int | None] = 16,
 | 
					 | 
				
			||||||
        write_limit: int | tuple[int, int | None] = 2**15,
 | 
					 | 
				
			||||||
    ) -> None:
 | 
					 | 
				
			||||||
        self.protocol: ServerProtocol
 | 
					 | 
				
			||||||
        super().__init__(
 | 
					 | 
				
			||||||
            protocol,
 | 
					 | 
				
			||||||
            ping_interval=ping_interval,
 | 
					 | 
				
			||||||
            ping_timeout=ping_timeout,
 | 
					 | 
				
			||||||
            close_timeout=close_timeout,
 | 
					 | 
				
			||||||
            max_queue=max_queue,
 | 
					 | 
				
			||||||
            write_limit=write_limit,
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        self.server = server
 | 
					 | 
				
			||||||
        self.request_rcvd: asyncio.Future[None] = self.loop.create_future()
 | 
					 | 
				
			||||||
        self.username: str  # see basic_auth()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def respond(self, status: StatusLike, text: str) -> Response:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Create a plain text HTTP response.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        ``process_request`` and ``process_response`` may call this method to
 | 
					 | 
				
			||||||
        return an HTTP response instead of performing the WebSocket opening
 | 
					 | 
				
			||||||
        handshake.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        You can modify the response before returning it, for example by changing
 | 
					 | 
				
			||||||
        HTTP headers.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Args:
 | 
					 | 
				
			||||||
            status: HTTP status code.
 | 
					 | 
				
			||||||
            text: HTTP response body; it will be encoded to UTF-8.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Returns:
 | 
					 | 
				
			||||||
            HTTP response to send to the client.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        return self.protocol.reject(status, text)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    async def handshake(
 | 
					 | 
				
			||||||
        self,
 | 
					 | 
				
			||||||
        process_request: (
 | 
					 | 
				
			||||||
            Callable[
 | 
					 | 
				
			||||||
                [ServerConnection, Request],
 | 
					 | 
				
			||||||
                Awaitable[Response | None] | Response | None,
 | 
					 | 
				
			||||||
            ]
 | 
					 | 
				
			||||||
            | None
 | 
					 | 
				
			||||||
        ) = None,
 | 
					 | 
				
			||||||
        process_response: (
 | 
					 | 
				
			||||||
            Callable[
 | 
					 | 
				
			||||||
                [ServerConnection, Request, Response],
 | 
					 | 
				
			||||||
                Awaitable[Response | None] | Response | None,
 | 
					 | 
				
			||||||
            ]
 | 
					 | 
				
			||||||
            | None
 | 
					 | 
				
			||||||
        ) = None,
 | 
					 | 
				
			||||||
        server_header: str | None = SERVER,
 | 
					 | 
				
			||||||
    ) -> None:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Perform the opening handshake.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        await asyncio.wait(
 | 
					 | 
				
			||||||
            [self.request_rcvd, self.connection_lost_waiter],
 | 
					 | 
				
			||||||
            return_when=asyncio.FIRST_COMPLETED,
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if self.request is not None:
 | 
					 | 
				
			||||||
            async with self.send_context(expected_state=CONNECTING):
 | 
					 | 
				
			||||||
                response = None
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                if process_request is not None:
 | 
					 | 
				
			||||||
                    try:
 | 
					 | 
				
			||||||
                        response = process_request(self, self.request)
 | 
					 | 
				
			||||||
                        if isinstance(response, Awaitable):
 | 
					 | 
				
			||||||
                            response = await response
 | 
					 | 
				
			||||||
                    except Exception as exc:
 | 
					 | 
				
			||||||
                        self.protocol.handshake_exc = exc
 | 
					 | 
				
			||||||
                        response = self.protocol.reject(
 | 
					 | 
				
			||||||
                            http.HTTPStatus.INTERNAL_SERVER_ERROR,
 | 
					 | 
				
			||||||
                            (
 | 
					 | 
				
			||||||
                                "Failed to open a WebSocket connection.\n"
 | 
					 | 
				
			||||||
                                "See server log for more information.\n"
 | 
					 | 
				
			||||||
                            ),
 | 
					 | 
				
			||||||
                        )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                if response is None:
 | 
					 | 
				
			||||||
                    if self.server.is_serving():
 | 
					 | 
				
			||||||
                        self.response = self.protocol.accept(self.request)
 | 
					 | 
				
			||||||
                    else:
 | 
					 | 
				
			||||||
                        self.response = self.protocol.reject(
 | 
					 | 
				
			||||||
                            http.HTTPStatus.SERVICE_UNAVAILABLE,
 | 
					 | 
				
			||||||
                            "Server is shutting down.\n",
 | 
					 | 
				
			||||||
                        )
 | 
					 | 
				
			||||||
                else:
 | 
					 | 
				
			||||||
                    assert isinstance(response, Response)  # help mypy
 | 
					 | 
				
			||||||
                    self.response = response
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                if server_header:
 | 
					 | 
				
			||||||
                    self.response.headers["Server"] = server_header
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                response = None
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                if process_response is not None:
 | 
					 | 
				
			||||||
                    try:
 | 
					 | 
				
			||||||
                        response = process_response(self, self.request, self.response)
 | 
					 | 
				
			||||||
                        if isinstance(response, Awaitable):
 | 
					 | 
				
			||||||
                            response = await response
 | 
					 | 
				
			||||||
                    except Exception as exc:
 | 
					 | 
				
			||||||
                        self.protocol.handshake_exc = exc
 | 
					 | 
				
			||||||
                        response = self.protocol.reject(
 | 
					 | 
				
			||||||
                            http.HTTPStatus.INTERNAL_SERVER_ERROR,
 | 
					 | 
				
			||||||
                            (
 | 
					 | 
				
			||||||
                                "Failed to open a WebSocket connection.\n"
 | 
					 | 
				
			||||||
                                "See server log for more information.\n"
 | 
					 | 
				
			||||||
                            ),
 | 
					 | 
				
			||||||
                        )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                if response is not None:
 | 
					 | 
				
			||||||
                    assert isinstance(response, Response)  # help mypy
 | 
					 | 
				
			||||||
                    self.response = response
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                self.protocol.send_response(self.response)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # self.protocol.handshake_exc is always set when the connection is lost
 | 
					 | 
				
			||||||
        # before receiving a request, when the request cannot be parsed, when
 | 
					 | 
				
			||||||
        # the handshake encounters an error, or when process_request or
 | 
					 | 
				
			||||||
        # process_response sends a HTTP response that rejects the handshake.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if self.protocol.handshake_exc is not None:
 | 
					 | 
				
			||||||
            raise self.protocol.handshake_exc
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def process_event(self, event: Event) -> None:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Process one incoming event.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        # First event - handshake request.
 | 
					 | 
				
			||||||
        if self.request is None:
 | 
					 | 
				
			||||||
            assert isinstance(event, Request)
 | 
					 | 
				
			||||||
            self.request = event
 | 
					 | 
				
			||||||
            self.request_rcvd.set_result(None)
 | 
					 | 
				
			||||||
        # Later events - frames.
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            super().process_event(event)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def connection_made(self, transport: asyncio.BaseTransport) -> None:
 | 
					 | 
				
			||||||
        super().connection_made(transport)
 | 
					 | 
				
			||||||
        self.server.start_connection_handler(self)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class Server:
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    WebSocket server returned by :func:`serve`.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    This class mirrors the API of :class:`asyncio.Server`.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    It keeps track of WebSocket connections in order to close them properly
 | 
					 | 
				
			||||||
    when shutting down.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Args:
 | 
					 | 
				
			||||||
        handler: Connection handler. It receives the WebSocket connection,
 | 
					 | 
				
			||||||
            which is a :class:`ServerConnection`, in argument.
 | 
					 | 
				
			||||||
        process_request: Intercept the request during the opening handshake.
 | 
					 | 
				
			||||||
            Return an HTTP response to force the response. Return :obj:`None` to
 | 
					 | 
				
			||||||
            continue normally. When you force an HTTP 101 Continue response, the
 | 
					 | 
				
			||||||
            handshake is successful. Else, the connection is aborted.
 | 
					 | 
				
			||||||
            ``process_request`` may be a function or a coroutine.
 | 
					 | 
				
			||||||
        process_response: Intercept the response during the opening handshake.
 | 
					 | 
				
			||||||
            Modify the response or return a new HTTP response to force the
 | 
					 | 
				
			||||||
            response. Return :obj:`None` to continue normally. When you force an
 | 
					 | 
				
			||||||
            HTTP 101 Continue response, the handshake is successful. Else, the
 | 
					 | 
				
			||||||
            connection is aborted. ``process_response`` may be a function or a
 | 
					 | 
				
			||||||
            coroutine.
 | 
					 | 
				
			||||||
        server_header: Value of  the ``Server`` response header.
 | 
					 | 
				
			||||||
            It defaults to ``"Python/x.y.z websockets/X.Y"``. Setting it to
 | 
					 | 
				
			||||||
            :obj:`None` removes the header.
 | 
					 | 
				
			||||||
        open_timeout: Timeout for opening connections in seconds.
 | 
					 | 
				
			||||||
            :obj:`None` disables the timeout.
 | 
					 | 
				
			||||||
        logger: Logger for this server.
 | 
					 | 
				
			||||||
            It defaults to ``logging.getLogger("websockets.server")``.
 | 
					 | 
				
			||||||
            See the :doc:`logging guide <../../topics/logging>` for details.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __init__(
 | 
					 | 
				
			||||||
        self,
 | 
					 | 
				
			||||||
        handler: Callable[[ServerConnection], Awaitable[None]],
 | 
					 | 
				
			||||||
        *,
 | 
					 | 
				
			||||||
        process_request: (
 | 
					 | 
				
			||||||
            Callable[
 | 
					 | 
				
			||||||
                [ServerConnection, Request],
 | 
					 | 
				
			||||||
                Awaitable[Response | None] | Response | None,
 | 
					 | 
				
			||||||
            ]
 | 
					 | 
				
			||||||
            | None
 | 
					 | 
				
			||||||
        ) = None,
 | 
					 | 
				
			||||||
        process_response: (
 | 
					 | 
				
			||||||
            Callable[
 | 
					 | 
				
			||||||
                [ServerConnection, Request, Response],
 | 
					 | 
				
			||||||
                Awaitable[Response | None] | Response | None,
 | 
					 | 
				
			||||||
            ]
 | 
					 | 
				
			||||||
            | None
 | 
					 | 
				
			||||||
        ) = None,
 | 
					 | 
				
			||||||
        server_header: str | None = SERVER,
 | 
					 | 
				
			||||||
        open_timeout: float | None = 10,
 | 
					 | 
				
			||||||
        logger: LoggerLike | None = None,
 | 
					 | 
				
			||||||
    ) -> None:
 | 
					 | 
				
			||||||
        self.loop = asyncio.get_running_loop()
 | 
					 | 
				
			||||||
        self.handler = handler
 | 
					 | 
				
			||||||
        self.process_request = process_request
 | 
					 | 
				
			||||||
        self.process_response = process_response
 | 
					 | 
				
			||||||
        self.server_header = server_header
 | 
					 | 
				
			||||||
        self.open_timeout = open_timeout
 | 
					 | 
				
			||||||
        if logger is None:
 | 
					 | 
				
			||||||
            logger = logging.getLogger("websockets.server")
 | 
					 | 
				
			||||||
        self.logger = logger
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # Keep track of active connections.
 | 
					 | 
				
			||||||
        self.handlers: dict[ServerConnection, asyncio.Task[None]] = {}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # Task responsible for closing the server and terminating connections.
 | 
					 | 
				
			||||||
        self.close_task: asyncio.Task[None] | None = None
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # Completed when the server is closed and connections are terminated.
 | 
					 | 
				
			||||||
        self.closed_waiter: asyncio.Future[None] = self.loop.create_future()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    @property
 | 
					 | 
				
			||||||
    def connections(self) -> set[ServerConnection]:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Set of active connections.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        This property contains all connections that completed the opening
 | 
					 | 
				
			||||||
        handshake successfully and didn't start the closing handshake yet.
 | 
					 | 
				
			||||||
        It can be useful in combination with :func:`~broadcast`.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        return {connection for connection in self.handlers if connection.state is OPEN}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def wrap(self, server: asyncio.Server) -> None:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Attach to a given :class:`asyncio.Server`.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Since :meth:`~asyncio.loop.create_server` doesn't support injecting a
 | 
					 | 
				
			||||||
        custom ``Server`` class, the easiest solution that doesn't rely on
 | 
					 | 
				
			||||||
        private :mod:`asyncio` APIs is to:
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        - instantiate a :class:`Server`
 | 
					 | 
				
			||||||
        - give the protocol factory a reference to that instance
 | 
					 | 
				
			||||||
        - call :meth:`~asyncio.loop.create_server` with the factory
 | 
					 | 
				
			||||||
        - attach the resulting :class:`asyncio.Server` with this method
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        self.server = server
 | 
					 | 
				
			||||||
        for sock in server.sockets:
 | 
					 | 
				
			||||||
            if sock.family == socket.AF_INET:
 | 
					 | 
				
			||||||
                name = "%s:%d" % sock.getsockname()
 | 
					 | 
				
			||||||
            elif sock.family == socket.AF_INET6:
 | 
					 | 
				
			||||||
                name = "[%s]:%d" % sock.getsockname()[:2]
 | 
					 | 
				
			||||||
            elif sock.family == socket.AF_UNIX:
 | 
					 | 
				
			||||||
                name = sock.getsockname()
 | 
					 | 
				
			||||||
            # In the unlikely event that someone runs websockets over a
 | 
					 | 
				
			||||||
            # protocol other than IP or Unix sockets, avoid crashing.
 | 
					 | 
				
			||||||
            else:  # pragma: no cover
 | 
					 | 
				
			||||||
                name = str(sock.getsockname())
 | 
					 | 
				
			||||||
            self.logger.info("server listening on %s", name)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    async def conn_handler(self, connection: ServerConnection) -> None:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Handle the lifecycle of a WebSocket connection.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Since this method doesn't have a caller that can handle exceptions,
 | 
					 | 
				
			||||||
        it attempts to log relevant ones.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        It guarantees that the TCP connection is closed before exiting.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        try:
 | 
					 | 
				
			||||||
            async with asyncio_timeout(self.open_timeout):
 | 
					 | 
				
			||||||
                try:
 | 
					 | 
				
			||||||
                    await connection.handshake(
 | 
					 | 
				
			||||||
                        self.process_request,
 | 
					 | 
				
			||||||
                        self.process_response,
 | 
					 | 
				
			||||||
                        self.server_header,
 | 
					 | 
				
			||||||
                    )
 | 
					 | 
				
			||||||
                except asyncio.CancelledError:
 | 
					 | 
				
			||||||
                    connection.close_transport()
 | 
					 | 
				
			||||||
                    raise
 | 
					 | 
				
			||||||
                except Exception:
 | 
					 | 
				
			||||||
                    connection.logger.error("opening handshake failed", exc_info=True)
 | 
					 | 
				
			||||||
                    connection.close_transport()
 | 
					 | 
				
			||||||
                    return
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            assert connection.protocol.state is OPEN
 | 
					 | 
				
			||||||
            try:
 | 
					 | 
				
			||||||
                connection.start_keepalive()
 | 
					 | 
				
			||||||
                await self.handler(connection)
 | 
					 | 
				
			||||||
            except Exception:
 | 
					 | 
				
			||||||
                connection.logger.error("connection handler failed", exc_info=True)
 | 
					 | 
				
			||||||
                await connection.close(CloseCode.INTERNAL_ERROR)
 | 
					 | 
				
			||||||
            else:
 | 
					 | 
				
			||||||
                await connection.close()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        except TimeoutError:
 | 
					 | 
				
			||||||
            # When the opening handshake times out, there's nothing to log.
 | 
					 | 
				
			||||||
            pass
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        except Exception:  # pragma: no cover
 | 
					 | 
				
			||||||
            # Don't leak connections on unexpected errors.
 | 
					 | 
				
			||||||
            connection.transport.abort()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        finally:
 | 
					 | 
				
			||||||
            # Registration is tied to the lifecycle of conn_handler() because
 | 
					 | 
				
			||||||
            # the server waits for connection handlers to terminate, even if
 | 
					 | 
				
			||||||
            # all connections are already closed.
 | 
					 | 
				
			||||||
            del self.handlers[connection]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def start_connection_handler(self, connection: ServerConnection) -> None:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Register a connection with this server.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        # The connection must be registered in self.handlers immediately.
 | 
					 | 
				
			||||||
        # If it was registered in conn_handler(), a race condition could
 | 
					 | 
				
			||||||
        # happen when closing the server after scheduling conn_handler()
 | 
					 | 
				
			||||||
        # but before it starts executing.
 | 
					 | 
				
			||||||
        self.handlers[connection] = self.loop.create_task(self.conn_handler(connection))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def close(self, close_connections: bool = True) -> None:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Close the server.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        * Close the underlying :class:`asyncio.Server`.
 | 
					 | 
				
			||||||
        * When ``close_connections`` is :obj:`True`, which is the default,
 | 
					 | 
				
			||||||
          close existing connections. Specifically:
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
          * Reject opening WebSocket connections with an HTTP 503 (service
 | 
					 | 
				
			||||||
            unavailable) error. This happens when the server accepted the TCP
 | 
					 | 
				
			||||||
            connection but didn't complete the opening handshake before closing.
 | 
					 | 
				
			||||||
          * Close open WebSocket connections with close code 1001 (going away).
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        * Wait until all connection handlers terminate.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        :meth:`close` is idempotent.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        if self.close_task is None:
 | 
					 | 
				
			||||||
            self.close_task = self.get_loop().create_task(
 | 
					 | 
				
			||||||
                self._close(close_connections)
 | 
					 | 
				
			||||||
            )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    async def _close(self, close_connections: bool) -> None:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Implementation of :meth:`close`.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        This calls :meth:`~asyncio.Server.close` on the underlying
 | 
					 | 
				
			||||||
        :class:`asyncio.Server` object to stop accepting new connections and
 | 
					 | 
				
			||||||
        then closes open connections with close code 1001.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        self.logger.info("server closing")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # Stop accepting new connections.
 | 
					 | 
				
			||||||
        self.server.close()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # Wait until all accepted connections reach connection_made() and call
 | 
					 | 
				
			||||||
        # register(). See https://github.com/python/cpython/issues/79033 for
 | 
					 | 
				
			||||||
        # details. This workaround can be removed when dropping Python < 3.11.
 | 
					 | 
				
			||||||
        await asyncio.sleep(0)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if close_connections:
 | 
					 | 
				
			||||||
            # Close OPEN connections with close code 1001. After server.close(),
 | 
					 | 
				
			||||||
            # handshake() closes OPENING connections with an HTTP 503 error.
 | 
					 | 
				
			||||||
            close_tasks = [
 | 
					 | 
				
			||||||
                asyncio.create_task(connection.close(1001))
 | 
					 | 
				
			||||||
                for connection in self.handlers
 | 
					 | 
				
			||||||
                if connection.protocol.state is not CONNECTING
 | 
					 | 
				
			||||||
            ]
 | 
					 | 
				
			||||||
            # asyncio.wait doesn't accept an empty first argument.
 | 
					 | 
				
			||||||
            if close_tasks:
 | 
					 | 
				
			||||||
                await asyncio.wait(close_tasks)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # Wait until all TCP connections are closed.
 | 
					 | 
				
			||||||
        await self.server.wait_closed()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # Wait until all connection handlers terminate.
 | 
					 | 
				
			||||||
        # asyncio.wait doesn't accept an empty first argument.
 | 
					 | 
				
			||||||
        if self.handlers:
 | 
					 | 
				
			||||||
            await asyncio.wait(self.handlers.values())
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # Tell wait_closed() to return.
 | 
					 | 
				
			||||||
        self.closed_waiter.set_result(None)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        self.logger.info("server closed")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    async def wait_closed(self) -> None:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Wait until the server is closed.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        When :meth:`wait_closed` returns, all TCP connections are closed and
 | 
					 | 
				
			||||||
        all connection handlers have returned.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        To ensure a fast shutdown, a connection handler should always be
 | 
					 | 
				
			||||||
        awaiting at least one of:
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        * :meth:`~ServerConnection.recv`: when the connection is closed,
 | 
					 | 
				
			||||||
          it raises :exc:`~websockets.exceptions.ConnectionClosedOK`;
 | 
					 | 
				
			||||||
        * :meth:`~ServerConnection.wait_closed`: when the connection is
 | 
					 | 
				
			||||||
          closed, it returns.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Then the connection handler is immediately notified of the shutdown;
 | 
					 | 
				
			||||||
        it can clean up and exit.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        await asyncio.shield(self.closed_waiter)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def get_loop(self) -> asyncio.AbstractEventLoop:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        See :meth:`asyncio.Server.get_loop`.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        return self.server.get_loop()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def is_serving(self) -> bool:  # pragma: no cover
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        See :meth:`asyncio.Server.is_serving`.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        return self.server.is_serving()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    async def start_serving(self) -> None:  # pragma: no cover
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        See :meth:`asyncio.Server.start_serving`.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Typical use::
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            server = await serve(..., start_serving=False)
 | 
					 | 
				
			||||||
            # perform additional setup here...
 | 
					 | 
				
			||||||
            # ... then start the server
 | 
					 | 
				
			||||||
            await server.start_serving()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        await self.server.start_serving()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    async def serve_forever(self) -> None:  # pragma: no cover
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        See :meth:`asyncio.Server.serve_forever`.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Typical use::
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            server = await serve(...)
 | 
					 | 
				
			||||||
            # this coroutine doesn't return
 | 
					 | 
				
			||||||
            # canceling it stops the server
 | 
					 | 
				
			||||||
            await server.serve_forever()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        This is an alternative to using :func:`serve` as an asynchronous context
 | 
					 | 
				
			||||||
        manager. Shutdown is triggered by canceling :meth:`serve_forever`
 | 
					 | 
				
			||||||
        instead of exiting a :func:`serve` context.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        await self.server.serve_forever()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    @property
 | 
					 | 
				
			||||||
    def sockets(self) -> Iterable[socket.socket]:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        See :attr:`asyncio.Server.sockets`.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        return self.server.sockets
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    async def __aenter__(self) -> Server:  # pragma: no cover
 | 
					 | 
				
			||||||
        return self
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    async def __aexit__(
 | 
					 | 
				
			||||||
        self,
 | 
					 | 
				
			||||||
        exc_type: type[BaseException] | None,
 | 
					 | 
				
			||||||
        exc_value: BaseException | None,
 | 
					 | 
				
			||||||
        traceback: TracebackType | None,
 | 
					 | 
				
			||||||
    ) -> None:  # pragma: no cover
 | 
					 | 
				
			||||||
        self.close()
 | 
					 | 
				
			||||||
        await self.wait_closed()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
# This is spelled in lower case because it's exposed as a callable in the API.
 | 
					 | 
				
			||||||
class serve:
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Create a WebSocket server listening on ``host`` and ``port``.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Whenever a client connects, the server creates a :class:`ServerConnection`,
 | 
					 | 
				
			||||||
    performs the opening handshake, and delegates to the ``handler`` coroutine.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    The handler receives the :class:`ServerConnection` instance, which you can
 | 
					 | 
				
			||||||
    use to send and receive messages.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Once the handler completes, either normally or with an exception, the server
 | 
					 | 
				
			||||||
    performs the closing handshake and closes the connection.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    This coroutine returns a :class:`Server` whose API mirrors
 | 
					 | 
				
			||||||
    :class:`asyncio.Server`. Treat it as an asynchronous context manager to
 | 
					 | 
				
			||||||
    ensure that the server will be closed::
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        from websockets.asyncio.server import serve
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        def handler(websocket):
 | 
					 | 
				
			||||||
            ...
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # set this future to exit the server
 | 
					 | 
				
			||||||
        stop = asyncio.get_running_loop().create_future()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        async with serve(handler, host, port):
 | 
					 | 
				
			||||||
            await stop
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Alternatively, call :meth:`~Server.serve_forever` to serve requests and
 | 
					 | 
				
			||||||
    cancel it to stop the server::
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        server = await serve(handler, host, port)
 | 
					 | 
				
			||||||
        await server.serve_forever()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Args:
 | 
					 | 
				
			||||||
        handler: Connection handler. It receives the WebSocket connection,
 | 
					 | 
				
			||||||
            which is a :class:`ServerConnection`, in argument.
 | 
					 | 
				
			||||||
        host: Network interfaces the server binds to.
 | 
					 | 
				
			||||||
            See :meth:`~asyncio.loop.create_server` for details.
 | 
					 | 
				
			||||||
        port: TCP port the server listens on.
 | 
					 | 
				
			||||||
            See :meth:`~asyncio.loop.create_server` for details.
 | 
					 | 
				
			||||||
        origins: Acceptable values of the ``Origin`` header, for defending
 | 
					 | 
				
			||||||
            against Cross-Site WebSocket Hijacking attacks. Include :obj:`None`
 | 
					 | 
				
			||||||
            in the list if the lack of an origin is acceptable.
 | 
					 | 
				
			||||||
        extensions: List of supported extensions, in order in which they
 | 
					 | 
				
			||||||
            should be negotiated and run.
 | 
					 | 
				
			||||||
        subprotocols: List of supported subprotocols, in order of decreasing
 | 
					 | 
				
			||||||
            preference.
 | 
					 | 
				
			||||||
        select_subprotocol: Callback for selecting a subprotocol among
 | 
					 | 
				
			||||||
            those supported by the client and the server. It receives a
 | 
					 | 
				
			||||||
            :class:`ServerConnection` (not a
 | 
					 | 
				
			||||||
            :class:`~websockets.server.ServerProtocol`!) instance and a list of
 | 
					 | 
				
			||||||
            subprotocols offered by the client. Other than the first argument,
 | 
					 | 
				
			||||||
            it has the same behavior as the
 | 
					 | 
				
			||||||
            :meth:`ServerProtocol.select_subprotocol
 | 
					 | 
				
			||||||
            <websockets.server.ServerProtocol.select_subprotocol>` method.
 | 
					 | 
				
			||||||
        process_request: Intercept the request during the opening handshake.
 | 
					 | 
				
			||||||
            Return an HTTP response to force the response or :obj:`None` to
 | 
					 | 
				
			||||||
            continue normally. When you force an HTTP 101 Continue response, the
 | 
					 | 
				
			||||||
            handshake is successful. Else, the connection is aborted.
 | 
					 | 
				
			||||||
            ``process_request`` may be a function or a coroutine.
 | 
					 | 
				
			||||||
        process_response: Intercept the response during the opening handshake.
 | 
					 | 
				
			||||||
            Return an HTTP response to force the response or :obj:`None` to
 | 
					 | 
				
			||||||
            continue normally. When you force an HTTP 101 Continue response, the
 | 
					 | 
				
			||||||
            handshake is successful. Else, the connection is aborted.
 | 
					 | 
				
			||||||
            ``process_response`` may be a function or a coroutine.
 | 
					 | 
				
			||||||
        server_header: Value of  the ``Server`` response header.
 | 
					 | 
				
			||||||
            It defaults to ``"Python/x.y.z websockets/X.Y"``. Setting it to
 | 
					 | 
				
			||||||
            :obj:`None` removes the header.
 | 
					 | 
				
			||||||
        compression: The "permessage-deflate" extension is enabled by default.
 | 
					 | 
				
			||||||
            Set ``compression`` to :obj:`None` to disable it. See the
 | 
					 | 
				
			||||||
            :doc:`compression guide <../../topics/compression>` for details.
 | 
					 | 
				
			||||||
        open_timeout: Timeout for opening connections in seconds.
 | 
					 | 
				
			||||||
            :obj:`None` disables the timeout.
 | 
					 | 
				
			||||||
        ping_interval: Interval between keepalive pings in seconds.
 | 
					 | 
				
			||||||
            :obj:`None` disables keepalive.
 | 
					 | 
				
			||||||
        ping_timeout: Timeout for keepalive pings in seconds.
 | 
					 | 
				
			||||||
            :obj:`None` disables timeouts.
 | 
					 | 
				
			||||||
        close_timeout: Timeout for closing connections in seconds.
 | 
					 | 
				
			||||||
            :obj:`None` disables the timeout.
 | 
					 | 
				
			||||||
        max_size: Maximum size of incoming messages in bytes.
 | 
					 | 
				
			||||||
            :obj:`None` disables the limit.
 | 
					 | 
				
			||||||
        max_queue: High-water mark of the buffer where frames are received.
 | 
					 | 
				
			||||||
            It defaults to 16 frames. The low-water mark defaults to ``max_queue
 | 
					 | 
				
			||||||
            // 4``. You may pass a ``(high, low)`` tuple to set the high-water
 | 
					 | 
				
			||||||
            and low-water marks.
 | 
					 | 
				
			||||||
        write_limit: High-water mark of write buffer in bytes. It is passed to
 | 
					 | 
				
			||||||
            :meth:`~asyncio.WriteTransport.set_write_buffer_limits`. It defaults
 | 
					 | 
				
			||||||
            to 32 KiB. You may pass a ``(high, low)`` tuple to set the
 | 
					 | 
				
			||||||
            high-water and low-water marks.
 | 
					 | 
				
			||||||
        logger: Logger for this server.
 | 
					 | 
				
			||||||
            It defaults to ``logging.getLogger("websockets.server")``. See the
 | 
					 | 
				
			||||||
            :doc:`logging guide <../../topics/logging>` for details.
 | 
					 | 
				
			||||||
        create_connection: Factory for the :class:`ServerConnection` managing
 | 
					 | 
				
			||||||
            the connection. Set it to a wrapper or a subclass to customize
 | 
					 | 
				
			||||||
            connection handling.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Any other keyword arguments are passed to the event loop's
 | 
					 | 
				
			||||||
    :meth:`~asyncio.loop.create_server` method.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    For example:
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    * You can set ``ssl`` to a :class:`~ssl.SSLContext` to enable TLS.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    * You can set ``sock`` to provide a preexisting TCP socket. You may call
 | 
					 | 
				
			||||||
      :func:`socket.create_server` (not to be confused with the event loop's
 | 
					 | 
				
			||||||
      :meth:`~asyncio.loop.create_server` method) to create a suitable server
 | 
					 | 
				
			||||||
      socket and customize it.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    * You can set ``start_serving`` to ``False`` to start accepting connections
 | 
					 | 
				
			||||||
      only after you call :meth:`~Server.start_serving()` or
 | 
					 | 
				
			||||||
      :meth:`~Server.serve_forever()`.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __init__(
 | 
					 | 
				
			||||||
        self,
 | 
					 | 
				
			||||||
        handler: Callable[[ServerConnection], Awaitable[None]],
 | 
					 | 
				
			||||||
        host: str | None = None,
 | 
					 | 
				
			||||||
        port: int | None = None,
 | 
					 | 
				
			||||||
        *,
 | 
					 | 
				
			||||||
        # WebSocket
 | 
					 | 
				
			||||||
        origins: Sequence[Origin | None] | None = None,
 | 
					 | 
				
			||||||
        extensions: Sequence[ServerExtensionFactory] | None = None,
 | 
					 | 
				
			||||||
        subprotocols: Sequence[Subprotocol] | None = None,
 | 
					 | 
				
			||||||
        select_subprotocol: (
 | 
					 | 
				
			||||||
            Callable[
 | 
					 | 
				
			||||||
                [ServerConnection, Sequence[Subprotocol]],
 | 
					 | 
				
			||||||
                Subprotocol | None,
 | 
					 | 
				
			||||||
            ]
 | 
					 | 
				
			||||||
            | None
 | 
					 | 
				
			||||||
        ) = None,
 | 
					 | 
				
			||||||
        process_request: (
 | 
					 | 
				
			||||||
            Callable[
 | 
					 | 
				
			||||||
                [ServerConnection, Request],
 | 
					 | 
				
			||||||
                Awaitable[Response | None] | Response | None,
 | 
					 | 
				
			||||||
            ]
 | 
					 | 
				
			||||||
            | None
 | 
					 | 
				
			||||||
        ) = None,
 | 
					 | 
				
			||||||
        process_response: (
 | 
					 | 
				
			||||||
            Callable[
 | 
					 | 
				
			||||||
                [ServerConnection, Request, Response],
 | 
					 | 
				
			||||||
                Awaitable[Response | None] | Response | None,
 | 
					 | 
				
			||||||
            ]
 | 
					 | 
				
			||||||
            | None
 | 
					 | 
				
			||||||
        ) = None,
 | 
					 | 
				
			||||||
        server_header: str | None = SERVER,
 | 
					 | 
				
			||||||
        compression: str | None = "deflate",
 | 
					 | 
				
			||||||
        # Timeouts
 | 
					 | 
				
			||||||
        open_timeout: float | None = 10,
 | 
					 | 
				
			||||||
        ping_interval: float | None = 20,
 | 
					 | 
				
			||||||
        ping_timeout: float | None = 20,
 | 
					 | 
				
			||||||
        close_timeout: float | None = 10,
 | 
					 | 
				
			||||||
        # Limits
 | 
					 | 
				
			||||||
        max_size: int | None = 2**20,
 | 
					 | 
				
			||||||
        max_queue: int | tuple[int, int | None] = 16,
 | 
					 | 
				
			||||||
        write_limit: int | tuple[int, int | None] = 2**15,
 | 
					 | 
				
			||||||
        # Logging
 | 
					 | 
				
			||||||
        logger: LoggerLike | None = None,
 | 
					 | 
				
			||||||
        # Escape hatch for advanced customization
 | 
					 | 
				
			||||||
        create_connection: type[ServerConnection] | None = None,
 | 
					 | 
				
			||||||
        # Other keyword arguments are passed to loop.create_server
 | 
					 | 
				
			||||||
        **kwargs: Any,
 | 
					 | 
				
			||||||
    ) -> None:
 | 
					 | 
				
			||||||
        if subprotocols is not None:
 | 
					 | 
				
			||||||
            validate_subprotocols(subprotocols)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if compression == "deflate":
 | 
					 | 
				
			||||||
            extensions = enable_server_permessage_deflate(extensions)
 | 
					 | 
				
			||||||
        elif compression is not None:
 | 
					 | 
				
			||||||
            raise ValueError(f"unsupported compression: {compression}")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if create_connection is None:
 | 
					 | 
				
			||||||
            create_connection = ServerConnection
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        self.server = Server(
 | 
					 | 
				
			||||||
            handler,
 | 
					 | 
				
			||||||
            process_request=process_request,
 | 
					 | 
				
			||||||
            process_response=process_response,
 | 
					 | 
				
			||||||
            server_header=server_header,
 | 
					 | 
				
			||||||
            open_timeout=open_timeout,
 | 
					 | 
				
			||||||
            logger=logger,
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if kwargs.get("ssl") is not None:
 | 
					 | 
				
			||||||
            kwargs.setdefault("ssl_handshake_timeout", open_timeout)
 | 
					 | 
				
			||||||
            if sys.version_info[:2] >= (3, 11):  # pragma: no branch
 | 
					 | 
				
			||||||
                kwargs.setdefault("ssl_shutdown_timeout", close_timeout)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        def factory() -> ServerConnection:
 | 
					 | 
				
			||||||
            """
 | 
					 | 
				
			||||||
            Create an asyncio protocol for managing a WebSocket connection.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            """
 | 
					 | 
				
			||||||
            # Create a closure to give select_subprotocol access to connection.
 | 
					 | 
				
			||||||
            protocol_select_subprotocol: (
 | 
					 | 
				
			||||||
                Callable[
 | 
					 | 
				
			||||||
                    [ServerProtocol, Sequence[Subprotocol]],
 | 
					 | 
				
			||||||
                    Subprotocol | None,
 | 
					 | 
				
			||||||
                ]
 | 
					 | 
				
			||||||
                | None
 | 
					 | 
				
			||||||
            ) = None
 | 
					 | 
				
			||||||
            if select_subprotocol is not None:
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                def protocol_select_subprotocol(
 | 
					 | 
				
			||||||
                    protocol: ServerProtocol,
 | 
					 | 
				
			||||||
                    subprotocols: Sequence[Subprotocol],
 | 
					 | 
				
			||||||
                ) -> Subprotocol | None:
 | 
					 | 
				
			||||||
                    # mypy doesn't know that select_subprotocol is immutable.
 | 
					 | 
				
			||||||
                    assert select_subprotocol is not None
 | 
					 | 
				
			||||||
                    # Ensure this function is only used in the intended context.
 | 
					 | 
				
			||||||
                    assert protocol is connection.protocol
 | 
					 | 
				
			||||||
                    return select_subprotocol(connection, subprotocols)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            # This is a protocol in the Sans-I/O implementation of websockets.
 | 
					 | 
				
			||||||
            protocol = ServerProtocol(
 | 
					 | 
				
			||||||
                origins=origins,
 | 
					 | 
				
			||||||
                extensions=extensions,
 | 
					 | 
				
			||||||
                subprotocols=subprotocols,
 | 
					 | 
				
			||||||
                select_subprotocol=protocol_select_subprotocol,
 | 
					 | 
				
			||||||
                max_size=max_size,
 | 
					 | 
				
			||||||
                logger=logger,
 | 
					 | 
				
			||||||
            )
 | 
					 | 
				
			||||||
            # This is a connection in websockets and a protocol in asyncio.
 | 
					 | 
				
			||||||
            connection = create_connection(
 | 
					 | 
				
			||||||
                protocol,
 | 
					 | 
				
			||||||
                self.server,
 | 
					 | 
				
			||||||
                ping_interval=ping_interval,
 | 
					 | 
				
			||||||
                ping_timeout=ping_timeout,
 | 
					 | 
				
			||||||
                close_timeout=close_timeout,
 | 
					 | 
				
			||||||
                max_queue=max_queue,
 | 
					 | 
				
			||||||
                write_limit=write_limit,
 | 
					 | 
				
			||||||
            )
 | 
					 | 
				
			||||||
            return connection
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        loop = asyncio.get_running_loop()
 | 
					 | 
				
			||||||
        if kwargs.pop("unix", False):
 | 
					 | 
				
			||||||
            self.create_server = loop.create_unix_server(factory, **kwargs)
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            # mypy cannot tell that kwargs must provide sock when port is None.
 | 
					 | 
				
			||||||
            self.create_server = loop.create_server(factory, host, port, **kwargs)  # type: ignore[arg-type]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    # async with serve(...) as ...: ...
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    async def __aenter__(self) -> Server:
 | 
					 | 
				
			||||||
        return await self
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    async def __aexit__(
 | 
					 | 
				
			||||||
        self,
 | 
					 | 
				
			||||||
        exc_type: type[BaseException] | None,
 | 
					 | 
				
			||||||
        exc_value: BaseException | None,
 | 
					 | 
				
			||||||
        traceback: TracebackType | None,
 | 
					 | 
				
			||||||
    ) -> None:
 | 
					 | 
				
			||||||
        self.server.close()
 | 
					 | 
				
			||||||
        await self.server.wait_closed()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    # ... = await serve(...)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __await__(self) -> Generator[Any, None, Server]:
 | 
					 | 
				
			||||||
        # Create a suitable iterator by calling __await__ on a coroutine.
 | 
					 | 
				
			||||||
        return self.__await_impl__().__await__()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    async def __await_impl__(self) -> Server:
 | 
					 | 
				
			||||||
        server = await self.create_server
 | 
					 | 
				
			||||||
        self.server.wrap(server)
 | 
					 | 
				
			||||||
        return self.server
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    # ... = yield from serve(...) - remove when dropping Python < 3.10
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    __iter__ = __await__
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def unix_serve(
 | 
					 | 
				
			||||||
    handler: Callable[[ServerConnection], Awaitable[None]],
 | 
					 | 
				
			||||||
    path: str | None = None,
 | 
					 | 
				
			||||||
    **kwargs: Any,
 | 
					 | 
				
			||||||
) -> Awaitable[Server]:
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Create a WebSocket server listening on a Unix socket.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    This function is identical to :func:`serve`, except the ``host`` and
 | 
					 | 
				
			||||||
    ``port`` arguments are replaced by ``path``. It's only available on Unix.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    It's useful for deploying a server behind a reverse proxy such as nginx.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Args:
 | 
					 | 
				
			||||||
        handler: Connection handler. It receives the WebSocket connection,
 | 
					 | 
				
			||||||
            which is a :class:`ServerConnection`, in argument.
 | 
					 | 
				
			||||||
        path: File system path to the Unix socket.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    return serve(handler, unix=True, path=path, **kwargs)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def is_credentials(credentials: Any) -> bool:
 | 
					 | 
				
			||||||
    try:
 | 
					 | 
				
			||||||
        username, password = credentials
 | 
					 | 
				
			||||||
    except (TypeError, ValueError):
 | 
					 | 
				
			||||||
        return False
 | 
					 | 
				
			||||||
    else:
 | 
					 | 
				
			||||||
        return isinstance(username, str) and isinstance(password, str)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def basic_auth(
 | 
					 | 
				
			||||||
    realm: str = "",
 | 
					 | 
				
			||||||
    credentials: tuple[str, str] | Iterable[tuple[str, str]] | None = None,
 | 
					 | 
				
			||||||
    check_credentials: Callable[[str, str], Awaitable[bool] | bool] | None = None,
 | 
					 | 
				
			||||||
) -> Callable[[ServerConnection, Request], Awaitable[Response | None]]:
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Factory for ``process_request`` to enforce HTTP Basic Authentication.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    :func:`basic_auth` is designed to integrate with :func:`serve` as follows::
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        from websockets.asyncio.server import basic_auth, serve
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        async with serve(
 | 
					 | 
				
			||||||
            ...,
 | 
					 | 
				
			||||||
            process_request=basic_auth(
 | 
					 | 
				
			||||||
                realm="my dev server",
 | 
					 | 
				
			||||||
                credentials=("hello", "iloveyou"),
 | 
					 | 
				
			||||||
            ),
 | 
					 | 
				
			||||||
        ):
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    If authentication succeeds, the connection's ``username`` attribute is set.
 | 
					 | 
				
			||||||
    If it fails, the server responds with an HTTP 401 Unauthorized status.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    One of ``credentials`` or ``check_credentials`` must be provided; not both.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Args:
 | 
					 | 
				
			||||||
        realm: Scope of protection. It should contain only ASCII characters
 | 
					 | 
				
			||||||
            because the encoding of non-ASCII characters is undefined. Refer to
 | 
					 | 
				
			||||||
            section 2.2 of :rfc:`7235` for details.
 | 
					 | 
				
			||||||
        credentials: Hard coded authorized credentials. It can be a
 | 
					 | 
				
			||||||
            ``(username, password)`` pair or a list of such pairs.
 | 
					 | 
				
			||||||
        check_credentials: Function or coroutine that verifies credentials.
 | 
					 | 
				
			||||||
            It receives ``username`` and ``password`` arguments and returns
 | 
					 | 
				
			||||||
            whether they're valid.
 | 
					 | 
				
			||||||
    Raises:
 | 
					 | 
				
			||||||
        TypeError: If ``credentials`` or ``check_credentials`` is wrong.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    if (credentials is None) == (check_credentials is None):
 | 
					 | 
				
			||||||
        raise TypeError("provide either credentials or check_credentials")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    if credentials is not None:
 | 
					 | 
				
			||||||
        if is_credentials(credentials):
 | 
					 | 
				
			||||||
            credentials_list = [cast(Tuple[str, str], credentials)]
 | 
					 | 
				
			||||||
        elif isinstance(credentials, Iterable):
 | 
					 | 
				
			||||||
            credentials_list = list(cast(Iterable[Tuple[str, str]], credentials))
 | 
					 | 
				
			||||||
            if not all(is_credentials(item) for item in credentials_list):
 | 
					 | 
				
			||||||
                raise TypeError(f"invalid credentials argument: {credentials}")
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            raise TypeError(f"invalid credentials argument: {credentials}")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        credentials_dict = dict(credentials_list)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        def check_credentials(username: str, password: str) -> bool:
 | 
					 | 
				
			||||||
            try:
 | 
					 | 
				
			||||||
                expected_password = credentials_dict[username]
 | 
					 | 
				
			||||||
            except KeyError:
 | 
					 | 
				
			||||||
                return False
 | 
					 | 
				
			||||||
            return hmac.compare_digest(expected_password, password)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    assert check_credentials is not None  # help mypy
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    async def process_request(
 | 
					 | 
				
			||||||
        connection: ServerConnection,
 | 
					 | 
				
			||||||
        request: Request,
 | 
					 | 
				
			||||||
    ) -> Response | None:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Perform HTTP Basic Authentication.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        If it succeeds, set the connection's ``username`` attribute and return
 | 
					 | 
				
			||||||
        :obj:`None`. If it fails, return an HTTP 401 Unauthorized responss.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        try:
 | 
					 | 
				
			||||||
            authorization = request.headers["Authorization"]
 | 
					 | 
				
			||||||
        except KeyError:
 | 
					 | 
				
			||||||
            response = connection.respond(
 | 
					 | 
				
			||||||
                http.HTTPStatus.UNAUTHORIZED,
 | 
					 | 
				
			||||||
                "Missing credentials\n",
 | 
					 | 
				
			||||||
            )
 | 
					 | 
				
			||||||
            response.headers["WWW-Authenticate"] = build_www_authenticate_basic(realm)
 | 
					 | 
				
			||||||
            return response
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        try:
 | 
					 | 
				
			||||||
            username, password = parse_authorization_basic(authorization)
 | 
					 | 
				
			||||||
        except InvalidHeader:
 | 
					 | 
				
			||||||
            response = connection.respond(
 | 
					 | 
				
			||||||
                http.HTTPStatus.UNAUTHORIZED,
 | 
					 | 
				
			||||||
                "Unsupported credentials\n",
 | 
					 | 
				
			||||||
            )
 | 
					 | 
				
			||||||
            response.headers["WWW-Authenticate"] = build_www_authenticate_basic(realm)
 | 
					 | 
				
			||||||
            return response
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        valid_credentials = check_credentials(username, password)
 | 
					 | 
				
			||||||
        if isinstance(valid_credentials, Awaitable):
 | 
					 | 
				
			||||||
            valid_credentials = await valid_credentials
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if not valid_credentials:
 | 
					 | 
				
			||||||
            response = connection.respond(
 | 
					 | 
				
			||||||
                http.HTTPStatus.UNAUTHORIZED,
 | 
					 | 
				
			||||||
                "Invalid credentials\n",
 | 
					 | 
				
			||||||
            )
 | 
					 | 
				
			||||||
            response.headers["WWW-Authenticate"] = build_www_authenticate_basic(realm)
 | 
					 | 
				
			||||||
            return response
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        connection.username = username
 | 
					 | 
				
			||||||
        return None
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    return process_request
 | 
					 | 
				
			||||||
@@ -1,6 +0,0 @@
 | 
				
			|||||||
from __future__ import annotations
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
# See #940 for why lazy_import isn't used here for backwards compatibility.
 | 
					 | 
				
			||||||
# See #1400 for why listing compatibility imports in __all__ helps PyCharm.
 | 
					 | 
				
			||||||
from .legacy.auth import *
 | 
					 | 
				
			||||||
from .legacy.auth import __all__  # noqa: F401
 | 
					 | 
				
			||||||
@@ -1,393 +0,0 @@
 | 
				
			|||||||
from __future__ import annotations
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
import os
 | 
					 | 
				
			||||||
import random
 | 
					 | 
				
			||||||
import warnings
 | 
					 | 
				
			||||||
from typing import Any, Generator, Sequence
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
from .datastructures import Headers, MultipleValuesError
 | 
					 | 
				
			||||||
from .exceptions import (
 | 
					 | 
				
			||||||
    InvalidHandshake,
 | 
					 | 
				
			||||||
    InvalidHeader,
 | 
					 | 
				
			||||||
    InvalidHeaderValue,
 | 
					 | 
				
			||||||
    InvalidStatus,
 | 
					 | 
				
			||||||
    InvalidUpgrade,
 | 
					 | 
				
			||||||
    NegotiationError,
 | 
					 | 
				
			||||||
)
 | 
					 | 
				
			||||||
from .extensions import ClientExtensionFactory, Extension
 | 
					 | 
				
			||||||
from .headers import (
 | 
					 | 
				
			||||||
    build_authorization_basic,
 | 
					 | 
				
			||||||
    build_extension,
 | 
					 | 
				
			||||||
    build_host,
 | 
					 | 
				
			||||||
    build_subprotocol,
 | 
					 | 
				
			||||||
    parse_connection,
 | 
					 | 
				
			||||||
    parse_extension,
 | 
					 | 
				
			||||||
    parse_subprotocol,
 | 
					 | 
				
			||||||
    parse_upgrade,
 | 
					 | 
				
			||||||
)
 | 
					 | 
				
			||||||
from .http11 import Request, Response
 | 
					 | 
				
			||||||
from .protocol import CLIENT, CONNECTING, OPEN, Protocol, State
 | 
					 | 
				
			||||||
from .typing import (
 | 
					 | 
				
			||||||
    ConnectionOption,
 | 
					 | 
				
			||||||
    ExtensionHeader,
 | 
					 | 
				
			||||||
    LoggerLike,
 | 
					 | 
				
			||||||
    Origin,
 | 
					 | 
				
			||||||
    Subprotocol,
 | 
					 | 
				
			||||||
    UpgradeProtocol,
 | 
					 | 
				
			||||||
)
 | 
					 | 
				
			||||||
from .uri import WebSocketURI
 | 
					 | 
				
			||||||
from .utils import accept_key, generate_key
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
# See #940 for why lazy_import isn't used here for backwards compatibility.
 | 
					 | 
				
			||||||
# See #1400 for why listing compatibility imports in __all__ helps PyCharm.
 | 
					 | 
				
			||||||
from .legacy.client import *  # isort:skip  # noqa: I001
 | 
					 | 
				
			||||||
from .legacy.client import __all__ as legacy__all__
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
__all__ = ["ClientProtocol"] + legacy__all__
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class ClientProtocol(Protocol):
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Sans-I/O implementation of a WebSocket client connection.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Args:
 | 
					 | 
				
			||||||
        wsuri: URI of the WebSocket server, parsed
 | 
					 | 
				
			||||||
            with :func:`~websockets.uri.parse_uri`.
 | 
					 | 
				
			||||||
        origin: Value of the ``Origin`` header. This is useful when connecting
 | 
					 | 
				
			||||||
            to a server that validates the ``Origin`` header to defend against
 | 
					 | 
				
			||||||
            Cross-Site WebSocket Hijacking attacks.
 | 
					 | 
				
			||||||
        extensions: List of supported extensions, in order in which they
 | 
					 | 
				
			||||||
            should be tried.
 | 
					 | 
				
			||||||
        subprotocols: List of supported subprotocols, in order of decreasing
 | 
					 | 
				
			||||||
            preference.
 | 
					 | 
				
			||||||
        state: Initial state of the WebSocket connection.
 | 
					 | 
				
			||||||
        max_size: Maximum size of incoming messages in bytes;
 | 
					 | 
				
			||||||
            :obj:`None` disables the limit.
 | 
					 | 
				
			||||||
        logger: Logger for this connection;
 | 
					 | 
				
			||||||
            defaults to ``logging.getLogger("websockets.client")``;
 | 
					 | 
				
			||||||
            see the :doc:`logging guide <../../topics/logging>` for details.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __init__(
 | 
					 | 
				
			||||||
        self,
 | 
					 | 
				
			||||||
        wsuri: WebSocketURI,
 | 
					 | 
				
			||||||
        *,
 | 
					 | 
				
			||||||
        origin: Origin | None = None,
 | 
					 | 
				
			||||||
        extensions: Sequence[ClientExtensionFactory] | None = None,
 | 
					 | 
				
			||||||
        subprotocols: Sequence[Subprotocol] | None = None,
 | 
					 | 
				
			||||||
        state: State = CONNECTING,
 | 
					 | 
				
			||||||
        max_size: int | None = 2**20,
 | 
					 | 
				
			||||||
        logger: LoggerLike | None = None,
 | 
					 | 
				
			||||||
    ) -> None:
 | 
					 | 
				
			||||||
        super().__init__(
 | 
					 | 
				
			||||||
            side=CLIENT,
 | 
					 | 
				
			||||||
            state=state,
 | 
					 | 
				
			||||||
            max_size=max_size,
 | 
					 | 
				
			||||||
            logger=logger,
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        self.wsuri = wsuri
 | 
					 | 
				
			||||||
        self.origin = origin
 | 
					 | 
				
			||||||
        self.available_extensions = extensions
 | 
					 | 
				
			||||||
        self.available_subprotocols = subprotocols
 | 
					 | 
				
			||||||
        self.key = generate_key()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def connect(self) -> Request:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Create a handshake request to open a connection.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        You must send the handshake request with :meth:`send_request`.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        You can modify it before sending it, for example to add HTTP headers.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Returns:
 | 
					 | 
				
			||||||
            WebSocket handshake request event to send to the server.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        headers = Headers()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        headers["Host"] = build_host(
 | 
					 | 
				
			||||||
            self.wsuri.host, self.wsuri.port, self.wsuri.secure
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if self.wsuri.user_info:
 | 
					 | 
				
			||||||
            headers["Authorization"] = build_authorization_basic(*self.wsuri.user_info)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if self.origin is not None:
 | 
					 | 
				
			||||||
            headers["Origin"] = self.origin
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        headers["Upgrade"] = "websocket"
 | 
					 | 
				
			||||||
        headers["Connection"] = "Upgrade"
 | 
					 | 
				
			||||||
        headers["Sec-WebSocket-Key"] = self.key
 | 
					 | 
				
			||||||
        headers["Sec-WebSocket-Version"] = "13"
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if self.available_extensions is not None:
 | 
					 | 
				
			||||||
            extensions_header = build_extension(
 | 
					 | 
				
			||||||
                [
 | 
					 | 
				
			||||||
                    (extension_factory.name, extension_factory.get_request_params())
 | 
					 | 
				
			||||||
                    for extension_factory in self.available_extensions
 | 
					 | 
				
			||||||
                ]
 | 
					 | 
				
			||||||
            )
 | 
					 | 
				
			||||||
            headers["Sec-WebSocket-Extensions"] = extensions_header
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if self.available_subprotocols is not None:
 | 
					 | 
				
			||||||
            protocol_header = build_subprotocol(self.available_subprotocols)
 | 
					 | 
				
			||||||
            headers["Sec-WebSocket-Protocol"] = protocol_header
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        return Request(self.wsuri.resource_name, headers)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def process_response(self, response: Response) -> None:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Check a handshake response.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Args:
 | 
					 | 
				
			||||||
            request: WebSocket handshake response received from the server.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Raises:
 | 
					 | 
				
			||||||
            InvalidHandshake: If the handshake response is invalid.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if response.status_code != 101:
 | 
					 | 
				
			||||||
            raise InvalidStatus(response)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        headers = response.headers
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        connection: list[ConnectionOption] = sum(
 | 
					 | 
				
			||||||
            [parse_connection(value) for value in headers.get_all("Connection")], []
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if not any(value.lower() == "upgrade" for value in connection):
 | 
					 | 
				
			||||||
            raise InvalidUpgrade(
 | 
					 | 
				
			||||||
                "Connection", ", ".join(connection) if connection else None
 | 
					 | 
				
			||||||
            )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        upgrade: list[UpgradeProtocol] = sum(
 | 
					 | 
				
			||||||
            [parse_upgrade(value) for value in headers.get_all("Upgrade")], []
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # For compatibility with non-strict implementations, ignore case when
 | 
					 | 
				
			||||||
        # checking the Upgrade header. It's supposed to be 'WebSocket'.
 | 
					 | 
				
			||||||
        if not (len(upgrade) == 1 and upgrade[0].lower() == "websocket"):
 | 
					 | 
				
			||||||
            raise InvalidUpgrade("Upgrade", ", ".join(upgrade) if upgrade else None)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        try:
 | 
					 | 
				
			||||||
            s_w_accept = headers["Sec-WebSocket-Accept"]
 | 
					 | 
				
			||||||
        except KeyError as exc:
 | 
					 | 
				
			||||||
            raise InvalidHeader("Sec-WebSocket-Accept") from exc
 | 
					 | 
				
			||||||
        except MultipleValuesError as exc:
 | 
					 | 
				
			||||||
            raise InvalidHeader("Sec-WebSocket-Accept", "multiple values") from exc
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if s_w_accept != accept_key(self.key):
 | 
					 | 
				
			||||||
            raise InvalidHeaderValue("Sec-WebSocket-Accept", s_w_accept)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        self.extensions = self.process_extensions(headers)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        self.subprotocol = self.process_subprotocol(headers)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def process_extensions(self, headers: Headers) -> list[Extension]:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Handle the Sec-WebSocket-Extensions HTTP response header.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Check that each extension is supported, as well as its parameters.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        :rfc:`6455` leaves the rules up to the specification of each
 | 
					 | 
				
			||||||
        extension.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        To provide this level of flexibility, for each extension accepted by
 | 
					 | 
				
			||||||
        the server, we check for a match with each extension available in the
 | 
					 | 
				
			||||||
        client configuration. If no match is found, an exception is raised.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        If several variants of the same extension are accepted by the server,
 | 
					 | 
				
			||||||
        it may be configured several times, which won't make sense in general.
 | 
					 | 
				
			||||||
        Extensions must implement their own requirements. For this purpose,
 | 
					 | 
				
			||||||
        the list of previously accepted extensions is provided.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Other requirements, for example related to mandatory extensions or the
 | 
					 | 
				
			||||||
        order of extensions, may be implemented by overriding this method.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Args:
 | 
					 | 
				
			||||||
            headers: WebSocket handshake response headers.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Returns:
 | 
					 | 
				
			||||||
            List of accepted extensions.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Raises:
 | 
					 | 
				
			||||||
            InvalidHandshake: To abort the handshake.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        accepted_extensions: list[Extension] = []
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        extensions = headers.get_all("Sec-WebSocket-Extensions")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if extensions:
 | 
					 | 
				
			||||||
            if self.available_extensions is None:
 | 
					 | 
				
			||||||
                raise NegotiationError("no extensions supported")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            parsed_extensions: list[ExtensionHeader] = sum(
 | 
					 | 
				
			||||||
                [parse_extension(header_value) for header_value in extensions], []
 | 
					 | 
				
			||||||
            )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            for name, response_params in parsed_extensions:
 | 
					 | 
				
			||||||
                for extension_factory in self.available_extensions:
 | 
					 | 
				
			||||||
                    # Skip non-matching extensions based on their name.
 | 
					 | 
				
			||||||
                    if extension_factory.name != name:
 | 
					 | 
				
			||||||
                        continue
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                    # Skip non-matching extensions based on their params.
 | 
					 | 
				
			||||||
                    try:
 | 
					 | 
				
			||||||
                        extension = extension_factory.process_response_params(
 | 
					 | 
				
			||||||
                            response_params, accepted_extensions
 | 
					 | 
				
			||||||
                        )
 | 
					 | 
				
			||||||
                    except NegotiationError:
 | 
					 | 
				
			||||||
                        continue
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                    # Add matching extension to the final list.
 | 
					 | 
				
			||||||
                    accepted_extensions.append(extension)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                    # Break out of the loop once we have a match.
 | 
					 | 
				
			||||||
                    break
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                # If we didn't break from the loop, no extension in our list
 | 
					 | 
				
			||||||
                # matched what the server sent. Fail the connection.
 | 
					 | 
				
			||||||
                else:
 | 
					 | 
				
			||||||
                    raise NegotiationError(
 | 
					 | 
				
			||||||
                        f"Unsupported extension: "
 | 
					 | 
				
			||||||
                        f"name = {name}, params = {response_params}"
 | 
					 | 
				
			||||||
                    )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        return accepted_extensions
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def process_subprotocol(self, headers: Headers) -> Subprotocol | None:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Handle the Sec-WebSocket-Protocol HTTP response header.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        If provided, check that it contains exactly one supported subprotocol.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Args:
 | 
					 | 
				
			||||||
            headers: WebSocket handshake response headers.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Returns:
 | 
					 | 
				
			||||||
           Subprotocol, if one was selected.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        subprotocol: Subprotocol | None = None
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        subprotocols = headers.get_all("Sec-WebSocket-Protocol")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if subprotocols:
 | 
					 | 
				
			||||||
            if self.available_subprotocols is None:
 | 
					 | 
				
			||||||
                raise NegotiationError("no subprotocols supported")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            parsed_subprotocols: Sequence[Subprotocol] = sum(
 | 
					 | 
				
			||||||
                [parse_subprotocol(header_value) for header_value in subprotocols], []
 | 
					 | 
				
			||||||
            )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            if len(parsed_subprotocols) > 1:
 | 
					 | 
				
			||||||
                raise InvalidHeader(
 | 
					 | 
				
			||||||
                    "Sec-WebSocket-Protocol",
 | 
					 | 
				
			||||||
                    f"multiple values: {', '.join(parsed_subprotocols)}",
 | 
					 | 
				
			||||||
                )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            subprotocol = parsed_subprotocols[0]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            if subprotocol not in self.available_subprotocols:
 | 
					 | 
				
			||||||
                raise NegotiationError(f"unsupported subprotocol: {subprotocol}")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        return subprotocol
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def send_request(self, request: Request) -> None:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Send a handshake request to the server.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Args:
 | 
					 | 
				
			||||||
            request: WebSocket handshake request event.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        if self.debug:
 | 
					 | 
				
			||||||
            self.logger.debug("> GET %s HTTP/1.1", request.path)
 | 
					 | 
				
			||||||
            for key, value in request.headers.raw_items():
 | 
					 | 
				
			||||||
                self.logger.debug("> %s: %s", key, value)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        self.writes.append(request.serialize())
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def parse(self) -> Generator[None, None, None]:
 | 
					 | 
				
			||||||
        if self.state is CONNECTING:
 | 
					 | 
				
			||||||
            try:
 | 
					 | 
				
			||||||
                response = yield from Response.parse(
 | 
					 | 
				
			||||||
                    self.reader.read_line,
 | 
					 | 
				
			||||||
                    self.reader.read_exact,
 | 
					 | 
				
			||||||
                    self.reader.read_to_eof,
 | 
					 | 
				
			||||||
                )
 | 
					 | 
				
			||||||
            except Exception as exc:
 | 
					 | 
				
			||||||
                self.handshake_exc = exc
 | 
					 | 
				
			||||||
                self.send_eof()
 | 
					 | 
				
			||||||
                self.parser = self.discard()
 | 
					 | 
				
			||||||
                next(self.parser)  # start coroutine
 | 
					 | 
				
			||||||
                yield
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            if self.debug:
 | 
					 | 
				
			||||||
                code, phrase = response.status_code, response.reason_phrase
 | 
					 | 
				
			||||||
                self.logger.debug("< HTTP/1.1 %d %s", code, phrase)
 | 
					 | 
				
			||||||
                for key, value in response.headers.raw_items():
 | 
					 | 
				
			||||||
                    self.logger.debug("< %s: %s", key, value)
 | 
					 | 
				
			||||||
                if response.body is not None:
 | 
					 | 
				
			||||||
                    self.logger.debug("< [body] (%d bytes)", len(response.body))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            try:
 | 
					 | 
				
			||||||
                self.process_response(response)
 | 
					 | 
				
			||||||
            except InvalidHandshake as exc:
 | 
					 | 
				
			||||||
                response._exception = exc
 | 
					 | 
				
			||||||
                self.events.append(response)
 | 
					 | 
				
			||||||
                self.handshake_exc = exc
 | 
					 | 
				
			||||||
                self.send_eof()
 | 
					 | 
				
			||||||
                self.parser = self.discard()
 | 
					 | 
				
			||||||
                next(self.parser)  # start coroutine
 | 
					 | 
				
			||||||
                yield
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            assert self.state is CONNECTING
 | 
					 | 
				
			||||||
            self.state = OPEN
 | 
					 | 
				
			||||||
            self.events.append(response)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        yield from super().parse()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class ClientConnection(ClientProtocol):
 | 
					 | 
				
			||||||
    def __init__(self, *args: Any, **kwargs: Any) -> None:
 | 
					 | 
				
			||||||
        warnings.warn(  # deprecated in 11.0 - 2023-04-02
 | 
					 | 
				
			||||||
            "ClientConnection was renamed to ClientProtocol",
 | 
					 | 
				
			||||||
            DeprecationWarning,
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        super().__init__(*args, **kwargs)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
BACKOFF_INITIAL_DELAY = float(os.environ.get("WEBSOCKETS_BACKOFF_INITIAL_DELAY", "5"))
 | 
					 | 
				
			||||||
BACKOFF_MIN_DELAY = float(os.environ.get("WEBSOCKETS_BACKOFF_MIN_DELAY", "3.1"))
 | 
					 | 
				
			||||||
BACKOFF_MAX_DELAY = float(os.environ.get("WEBSOCKETS_BACKOFF_MAX_DELAY", "90.0"))
 | 
					 | 
				
			||||||
BACKOFF_FACTOR = float(os.environ.get("WEBSOCKETS_BACKOFF_FACTOR", "1.618"))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def backoff(
 | 
					 | 
				
			||||||
    initial_delay: float = BACKOFF_INITIAL_DELAY,
 | 
					 | 
				
			||||||
    min_delay: float = BACKOFF_MIN_DELAY,
 | 
					 | 
				
			||||||
    max_delay: float = BACKOFF_MAX_DELAY,
 | 
					 | 
				
			||||||
    factor: float = BACKOFF_FACTOR,
 | 
					 | 
				
			||||||
) -> Generator[float, None, None]:
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Generate a series of backoff delays between reconnection attempts.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Yields:
 | 
					 | 
				
			||||||
        How many seconds to wait before retrying to connect.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    # Add a random initial delay between 0 and 5 seconds.
 | 
					 | 
				
			||||||
    # See 7.2.3. Recovering from Abnormal Closure in RFC 6455.
 | 
					 | 
				
			||||||
    yield random.random() * initial_delay
 | 
					 | 
				
			||||||
    delay = min_delay
 | 
					 | 
				
			||||||
    while delay < max_delay:
 | 
					 | 
				
			||||||
        yield delay
 | 
					 | 
				
			||||||
        delay *= factor
 | 
					 | 
				
			||||||
    while True:
 | 
					 | 
				
			||||||
        yield max_delay
 | 
					 | 
				
			||||||
@@ -1,12 +0,0 @@
 | 
				
			|||||||
from __future__ import annotations
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
import warnings
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
from .protocol import SEND_EOF, Protocol as Connection, Side, State  # noqa: F401
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
warnings.warn(  # deprecated in 11.0 - 2023-04-02
 | 
					 | 
				
			||||||
    "websockets.connection was renamed to websockets.protocol "
 | 
					 | 
				
			||||||
    "and Connection was renamed to Protocol",
 | 
					 | 
				
			||||||
    DeprecationWarning,
 | 
					 | 
				
			||||||
)
 | 
					 | 
				
			||||||
@@ -1,192 +0,0 @@
 | 
				
			|||||||
from __future__ import annotations
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
from typing import (
 | 
					 | 
				
			||||||
    Any,
 | 
					 | 
				
			||||||
    Iterable,
 | 
					 | 
				
			||||||
    Iterator,
 | 
					 | 
				
			||||||
    Mapping,
 | 
					 | 
				
			||||||
    MutableMapping,
 | 
					 | 
				
			||||||
    Protocol,
 | 
					 | 
				
			||||||
    Tuple,
 | 
					 | 
				
			||||||
    Union,
 | 
					 | 
				
			||||||
)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
__all__ = ["Headers", "HeadersLike", "MultipleValuesError"]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class MultipleValuesError(LookupError):
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Exception raised when :class:`Headers` has multiple values for a key.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __str__(self) -> str:
 | 
					 | 
				
			||||||
        # Implement the same logic as KeyError_str in Objects/exceptions.c.
 | 
					 | 
				
			||||||
        if len(self.args) == 1:
 | 
					 | 
				
			||||||
            return repr(self.args[0])
 | 
					 | 
				
			||||||
        return super().__str__()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class Headers(MutableMapping[str, str]):
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Efficient data structure for manipulating HTTP headers.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    A :class:`list` of ``(name, values)`` is inefficient for lookups.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    A :class:`dict` doesn't suffice because header names are case-insensitive
 | 
					 | 
				
			||||||
    and multiple occurrences of headers with the same name are possible.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    :class:`Headers` stores HTTP headers in a hybrid data structure to provide
 | 
					 | 
				
			||||||
    efficient insertions and lookups while preserving the original data.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    In order to account for multiple values with minimal hassle,
 | 
					 | 
				
			||||||
    :class:`Headers` follows this logic:
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    - When getting a header with ``headers[name]``:
 | 
					 | 
				
			||||||
        - if there's no value, :exc:`KeyError` is raised;
 | 
					 | 
				
			||||||
        - if there's exactly one value, it's returned;
 | 
					 | 
				
			||||||
        - if there's more than one value, :exc:`MultipleValuesError` is raised.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    - When setting a header with ``headers[name] = value``, the value is
 | 
					 | 
				
			||||||
      appended to the list of values for that header.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    - When deleting a header with ``del headers[name]``, all values for that
 | 
					 | 
				
			||||||
      header are removed (this is slow).
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Other methods for manipulating headers are consistent with this logic.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    As long as no header occurs multiple times, :class:`Headers` behaves like
 | 
					 | 
				
			||||||
    :class:`dict`, except keys are lower-cased to provide case-insensitivity.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Two methods support manipulating multiple values explicitly:
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    - :meth:`get_all` returns a list of all values for a header;
 | 
					 | 
				
			||||||
    - :meth:`raw_items` returns an iterator of ``(name, values)`` pairs.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    __slots__ = ["_dict", "_list"]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    # Like dict, Headers accepts an optional "mapping or iterable" argument.
 | 
					 | 
				
			||||||
    def __init__(self, *args: HeadersLike, **kwargs: str) -> None:
 | 
					 | 
				
			||||||
        self._dict: dict[str, list[str]] = {}
 | 
					 | 
				
			||||||
        self._list: list[tuple[str, str]] = []
 | 
					 | 
				
			||||||
        self.update(*args, **kwargs)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __str__(self) -> str:
 | 
					 | 
				
			||||||
        return "".join(f"{key}: {value}\r\n" for key, value in self._list) + "\r\n"
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __repr__(self) -> str:
 | 
					 | 
				
			||||||
        return f"{self.__class__.__name__}({self._list!r})"
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def copy(self) -> Headers:
 | 
					 | 
				
			||||||
        copy = self.__class__()
 | 
					 | 
				
			||||||
        copy._dict = self._dict.copy()
 | 
					 | 
				
			||||||
        copy._list = self._list.copy()
 | 
					 | 
				
			||||||
        return copy
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def serialize(self) -> bytes:
 | 
					 | 
				
			||||||
        # Since headers only contain ASCII characters, we can keep this simple.
 | 
					 | 
				
			||||||
        return str(self).encode()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    # Collection methods
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __contains__(self, key: object) -> bool:
 | 
					 | 
				
			||||||
        return isinstance(key, str) and key.lower() in self._dict
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __iter__(self) -> Iterator[str]:
 | 
					 | 
				
			||||||
        return iter(self._dict)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __len__(self) -> int:
 | 
					 | 
				
			||||||
        return len(self._dict)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    # MutableMapping methods
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __getitem__(self, key: str) -> str:
 | 
					 | 
				
			||||||
        value = self._dict[key.lower()]
 | 
					 | 
				
			||||||
        if len(value) == 1:
 | 
					 | 
				
			||||||
            return value[0]
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            raise MultipleValuesError(key)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __setitem__(self, key: str, value: str) -> None:
 | 
					 | 
				
			||||||
        self._dict.setdefault(key.lower(), []).append(value)
 | 
					 | 
				
			||||||
        self._list.append((key, value))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __delitem__(self, key: str) -> None:
 | 
					 | 
				
			||||||
        key_lower = key.lower()
 | 
					 | 
				
			||||||
        self._dict.__delitem__(key_lower)
 | 
					 | 
				
			||||||
        # This is inefficient. Fortunately deleting HTTP headers is uncommon.
 | 
					 | 
				
			||||||
        self._list = [(k, v) for k, v in self._list if k.lower() != key_lower]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __eq__(self, other: Any) -> bool:
 | 
					 | 
				
			||||||
        if not isinstance(other, Headers):
 | 
					 | 
				
			||||||
            return NotImplemented
 | 
					 | 
				
			||||||
        return self._dict == other._dict
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def clear(self) -> None:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Remove all headers.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        self._dict = {}
 | 
					 | 
				
			||||||
        self._list = []
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def update(self, *args: HeadersLike, **kwargs: str) -> None:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Update from a :class:`Headers` instance and/or keyword arguments.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        args = tuple(
 | 
					 | 
				
			||||||
            arg.raw_items() if isinstance(arg, Headers) else arg for arg in args
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        super().update(*args, **kwargs)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    # Methods for handling multiple values
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def get_all(self, key: str) -> list[str]:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Return the (possibly empty) list of all values for a header.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Args:
 | 
					 | 
				
			||||||
            key: Header name.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        return self._dict.get(key.lower(), [])
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def raw_items(self) -> Iterator[tuple[str, str]]:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Return an iterator of all values as ``(name, value)`` pairs.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        return iter(self._list)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
# copy of _typeshed.SupportsKeysAndGetItem.
 | 
					 | 
				
			||||||
class SupportsKeysAndGetItem(Protocol):  # pragma: no cover
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Dict-like types with ``keys() -> str`` and ``__getitem__(key: str) -> str`` methods.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def keys(self) -> Iterable[str]: ...
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __getitem__(self, key: str) -> str: ...
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
# Change to Headers | Mapping[str, str] | ... when dropping Python < 3.10.
 | 
					 | 
				
			||||||
HeadersLike = Union[
 | 
					 | 
				
			||||||
    Headers,
 | 
					 | 
				
			||||||
    Mapping[str, str],
 | 
					 | 
				
			||||||
    # Change to tuple[str, str] when dropping Python < 3.9.
 | 
					 | 
				
			||||||
    Iterable[Tuple[str, str]],
 | 
					 | 
				
			||||||
    SupportsKeysAndGetItem,
 | 
					 | 
				
			||||||
]
 | 
					 | 
				
			||||||
"""
 | 
					 | 
				
			||||||
Types accepted where :class:`Headers` is expected.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
In addition to :class:`Headers` itself, this includes dict-like types where both
 | 
					 | 
				
			||||||
keys and values are :class:`str`.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
"""
 | 
					 | 
				
			||||||
@@ -1,392 +0,0 @@
 | 
				
			|||||||
"""
 | 
					 | 
				
			||||||
:mod:`websockets.exceptions` defines the following hierarchy of exceptions.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
* :exc:`WebSocketException`
 | 
					 | 
				
			||||||
    * :exc:`ConnectionClosed`
 | 
					 | 
				
			||||||
        * :exc:`ConnectionClosedOK`
 | 
					 | 
				
			||||||
        * :exc:`ConnectionClosedError`
 | 
					 | 
				
			||||||
    * :exc:`InvalidURI`
 | 
					 | 
				
			||||||
    * :exc:`InvalidHandshake`
 | 
					 | 
				
			||||||
        * :exc:`SecurityError`
 | 
					 | 
				
			||||||
        * :exc:`InvalidMessage` (legacy)
 | 
					 | 
				
			||||||
        * :exc:`InvalidStatus`
 | 
					 | 
				
			||||||
        * :exc:`InvalidStatusCode` (legacy)
 | 
					 | 
				
			||||||
        * :exc:`InvalidHeader`
 | 
					 | 
				
			||||||
            * :exc:`InvalidHeaderFormat`
 | 
					 | 
				
			||||||
            * :exc:`InvalidHeaderValue`
 | 
					 | 
				
			||||||
            * :exc:`InvalidOrigin`
 | 
					 | 
				
			||||||
            * :exc:`InvalidUpgrade`
 | 
					 | 
				
			||||||
        * :exc:`NegotiationError`
 | 
					 | 
				
			||||||
            * :exc:`DuplicateParameter`
 | 
					 | 
				
			||||||
            * :exc:`InvalidParameterName`
 | 
					 | 
				
			||||||
            * :exc:`InvalidParameterValue`
 | 
					 | 
				
			||||||
        * :exc:`AbortHandshake` (legacy)
 | 
					 | 
				
			||||||
        * :exc:`RedirectHandshake` (legacy)
 | 
					 | 
				
			||||||
    * :exc:`ProtocolError` (Sans-I/O)
 | 
					 | 
				
			||||||
    * :exc:`PayloadTooBig` (Sans-I/O)
 | 
					 | 
				
			||||||
    * :exc:`InvalidState` (Sans-I/O)
 | 
					 | 
				
			||||||
    * :exc:`ConcurrencyError`
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
"""
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
from __future__ import annotations
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
import typing
 | 
					 | 
				
			||||||
import warnings
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
from .imports import lazy_import
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
__all__ = [
 | 
					 | 
				
			||||||
    "WebSocketException",
 | 
					 | 
				
			||||||
    "ConnectionClosed",
 | 
					 | 
				
			||||||
    "ConnectionClosedOK",
 | 
					 | 
				
			||||||
    "ConnectionClosedError",
 | 
					 | 
				
			||||||
    "InvalidURI",
 | 
					 | 
				
			||||||
    "InvalidHandshake",
 | 
					 | 
				
			||||||
    "SecurityError",
 | 
					 | 
				
			||||||
    "InvalidMessage",
 | 
					 | 
				
			||||||
    "InvalidStatus",
 | 
					 | 
				
			||||||
    "InvalidStatusCode",
 | 
					 | 
				
			||||||
    "InvalidHeader",
 | 
					 | 
				
			||||||
    "InvalidHeaderFormat",
 | 
					 | 
				
			||||||
    "InvalidHeaderValue",
 | 
					 | 
				
			||||||
    "InvalidOrigin",
 | 
					 | 
				
			||||||
    "InvalidUpgrade",
 | 
					 | 
				
			||||||
    "NegotiationError",
 | 
					 | 
				
			||||||
    "DuplicateParameter",
 | 
					 | 
				
			||||||
    "InvalidParameterName",
 | 
					 | 
				
			||||||
    "InvalidParameterValue",
 | 
					 | 
				
			||||||
    "AbortHandshake",
 | 
					 | 
				
			||||||
    "RedirectHandshake",
 | 
					 | 
				
			||||||
    "ProtocolError",
 | 
					 | 
				
			||||||
    "WebSocketProtocolError",
 | 
					 | 
				
			||||||
    "PayloadTooBig",
 | 
					 | 
				
			||||||
    "InvalidState",
 | 
					 | 
				
			||||||
    "ConcurrencyError",
 | 
					 | 
				
			||||||
]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class WebSocketException(Exception):
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Base class for all exceptions defined by websockets.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class ConnectionClosed(WebSocketException):
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Raised when trying to interact with a closed connection.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Attributes:
 | 
					 | 
				
			||||||
        rcvd: If a close frame was received, its code and reason are available
 | 
					 | 
				
			||||||
            in ``rcvd.code`` and ``rcvd.reason``.
 | 
					 | 
				
			||||||
        sent: If a close frame was sent, its code and reason are available
 | 
					 | 
				
			||||||
            in ``sent.code`` and ``sent.reason``.
 | 
					 | 
				
			||||||
        rcvd_then_sent: If close frames were received and sent, this attribute
 | 
					 | 
				
			||||||
            tells in which order this happened, from the perspective of this
 | 
					 | 
				
			||||||
            side of the connection.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __init__(
 | 
					 | 
				
			||||||
        self,
 | 
					 | 
				
			||||||
        rcvd: frames.Close | None,
 | 
					 | 
				
			||||||
        sent: frames.Close | None,
 | 
					 | 
				
			||||||
        rcvd_then_sent: bool | None = None,
 | 
					 | 
				
			||||||
    ) -> None:
 | 
					 | 
				
			||||||
        self.rcvd = rcvd
 | 
					 | 
				
			||||||
        self.sent = sent
 | 
					 | 
				
			||||||
        self.rcvd_then_sent = rcvd_then_sent
 | 
					 | 
				
			||||||
        assert (self.rcvd_then_sent is None) == (self.rcvd is None or self.sent is None)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __str__(self) -> str:
 | 
					 | 
				
			||||||
        if self.rcvd is None:
 | 
					 | 
				
			||||||
            if self.sent is None:
 | 
					 | 
				
			||||||
                return "no close frame received or sent"
 | 
					 | 
				
			||||||
            else:
 | 
					 | 
				
			||||||
                return f"sent {self.sent}; no close frame received"
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            if self.sent is None:
 | 
					 | 
				
			||||||
                return f"received {self.rcvd}; no close frame sent"
 | 
					 | 
				
			||||||
            else:
 | 
					 | 
				
			||||||
                if self.rcvd_then_sent:
 | 
					 | 
				
			||||||
                    return f"received {self.rcvd}; then sent {self.sent}"
 | 
					 | 
				
			||||||
                else:
 | 
					 | 
				
			||||||
                    return f"sent {self.sent}; then received {self.rcvd}"
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    # code and reason attributes are provided for backwards-compatibility
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    @property
 | 
					 | 
				
			||||||
    def code(self) -> int:
 | 
					 | 
				
			||||||
        warnings.warn(  # deprecated in 13.1
 | 
					 | 
				
			||||||
            "ConnectionClosed.code is deprecated; "
 | 
					 | 
				
			||||||
            "use Protocol.close_code or ConnectionClosed.rcvd.code",
 | 
					 | 
				
			||||||
            DeprecationWarning,
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        if self.rcvd is None:
 | 
					 | 
				
			||||||
            return frames.CloseCode.ABNORMAL_CLOSURE
 | 
					 | 
				
			||||||
        return self.rcvd.code
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    @property
 | 
					 | 
				
			||||||
    def reason(self) -> str:
 | 
					 | 
				
			||||||
        warnings.warn(  # deprecated in 13.1
 | 
					 | 
				
			||||||
            "ConnectionClosed.reason is deprecated; "
 | 
					 | 
				
			||||||
            "use Protocol.close_reason or ConnectionClosed.rcvd.reason",
 | 
					 | 
				
			||||||
            DeprecationWarning,
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        if self.rcvd is None:
 | 
					 | 
				
			||||||
            return ""
 | 
					 | 
				
			||||||
        return self.rcvd.reason
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class ConnectionClosedOK(ConnectionClosed):
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Like :exc:`ConnectionClosed`, when the connection terminated properly.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    A close code with code 1000 (OK) or 1001 (going away) or without a code was
 | 
					 | 
				
			||||||
    received and sent.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class ConnectionClosedError(ConnectionClosed):
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Like :exc:`ConnectionClosed`, when the connection terminated with an error.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    A close frame with a code other than 1000 (OK) or 1001 (going away) was
 | 
					 | 
				
			||||||
    received or sent, or the closing handshake didn't complete properly.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class InvalidURI(WebSocketException):
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Raised when connecting to a URI that isn't a valid WebSocket URI.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __init__(self, uri: str, msg: str) -> None:
 | 
					 | 
				
			||||||
        self.uri = uri
 | 
					 | 
				
			||||||
        self.msg = msg
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __str__(self) -> str:
 | 
					 | 
				
			||||||
        return f"{self.uri} isn't a valid URI: {self.msg}"
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class InvalidHandshake(WebSocketException):
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Base class for exceptions raised when the opening handshake fails.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class SecurityError(InvalidHandshake):
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Raised when a handshake request or response breaks a security rule.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Security limits can be configured with :doc:`environment variables
 | 
					 | 
				
			||||||
    <../reference/variables>`.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class InvalidStatus(InvalidHandshake):
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Raised when a handshake response rejects the WebSocket upgrade.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __init__(self, response: http11.Response) -> None:
 | 
					 | 
				
			||||||
        self.response = response
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __str__(self) -> str:
 | 
					 | 
				
			||||||
        return (
 | 
					 | 
				
			||||||
            "server rejected WebSocket connection: "
 | 
					 | 
				
			||||||
            f"HTTP {self.response.status_code:d}"
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class InvalidHeader(InvalidHandshake):
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Raised when an HTTP header doesn't have a valid format or value.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __init__(self, name: str, value: str | None = None) -> None:
 | 
					 | 
				
			||||||
        self.name = name
 | 
					 | 
				
			||||||
        self.value = value
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __str__(self) -> str:
 | 
					 | 
				
			||||||
        if self.value is None:
 | 
					 | 
				
			||||||
            return f"missing {self.name} header"
 | 
					 | 
				
			||||||
        elif self.value == "":
 | 
					 | 
				
			||||||
            return f"empty {self.name} header"
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            return f"invalid {self.name} header: {self.value}"
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class InvalidHeaderFormat(InvalidHeader):
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Raised when an HTTP header cannot be parsed.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    The format of the header doesn't match the grammar for that header.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __init__(self, name: str, error: str, header: str, pos: int) -> None:
 | 
					 | 
				
			||||||
        super().__init__(name, f"{error} at {pos} in {header}")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class InvalidHeaderValue(InvalidHeader):
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Raised when an HTTP header has a wrong value.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    The format of the header is correct but the value isn't acceptable.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class InvalidOrigin(InvalidHeader):
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Raised when the Origin header in a request isn't allowed.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __init__(self, origin: str | None) -> None:
 | 
					 | 
				
			||||||
        super().__init__("Origin", origin)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class InvalidUpgrade(InvalidHeader):
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Raised when the Upgrade or Connection header isn't correct.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class NegotiationError(InvalidHandshake):
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Raised when negotiating an extension or a subprotocol fails.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class DuplicateParameter(NegotiationError):
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Raised when a parameter name is repeated in an extension header.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __init__(self, name: str) -> None:
 | 
					 | 
				
			||||||
        self.name = name
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __str__(self) -> str:
 | 
					 | 
				
			||||||
        return f"duplicate parameter: {self.name}"
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class InvalidParameterName(NegotiationError):
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Raised when a parameter name in an extension header is invalid.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __init__(self, name: str) -> None:
 | 
					 | 
				
			||||||
        self.name = name
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __str__(self) -> str:
 | 
					 | 
				
			||||||
        return f"invalid parameter name: {self.name}"
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class InvalidParameterValue(NegotiationError):
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Raised when a parameter value in an extension header is invalid.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __init__(self, name: str, value: str | None) -> None:
 | 
					 | 
				
			||||||
        self.name = name
 | 
					 | 
				
			||||||
        self.value = value
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __str__(self) -> str:
 | 
					 | 
				
			||||||
        if self.value is None:
 | 
					 | 
				
			||||||
            return f"missing value for parameter {self.name}"
 | 
					 | 
				
			||||||
        elif self.value == "":
 | 
					 | 
				
			||||||
            return f"empty value for parameter {self.name}"
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            return f"invalid value for parameter {self.name}: {self.value}"
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class ProtocolError(WebSocketException):
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Raised when receiving or sending a frame that breaks the protocol.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    The Sans-I/O implementation raises this exception when:
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    * receiving or sending a frame that contains invalid data;
 | 
					 | 
				
			||||||
    * receiving or sending an invalid sequence of frames.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class PayloadTooBig(WebSocketException):
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Raised when parsing a frame with a payload that exceeds the maximum size.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    The Sans-I/O layer uses this exception internally. It doesn't bubble up to
 | 
					 | 
				
			||||||
    the I/O layer.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    The :meth:`~websockets.extensions.Extension.decode` method of extensions
 | 
					 | 
				
			||||||
    must raise :exc:`PayloadTooBig` if decoding a frame would exceed the limit.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class InvalidState(WebSocketException, AssertionError):
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Raised when sending a frame is forbidden in the current state.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Specifically, the Sans-I/O layer raises this exception when:
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    * sending a data frame to a connection in a state other
 | 
					 | 
				
			||||||
      :attr:`~websockets.protocol.State.OPEN`;
 | 
					 | 
				
			||||||
    * sending a control frame to a connection in a state other than
 | 
					 | 
				
			||||||
      :attr:`~websockets.protocol.State.OPEN` or
 | 
					 | 
				
			||||||
      :attr:`~websockets.protocol.State.CLOSING`.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class ConcurrencyError(WebSocketException, RuntimeError):
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Raised when receiving or sending messages concurrently.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    WebSocket is a connection-oriented protocol. Reads must be serialized; so
 | 
					 | 
				
			||||||
    must be writes. However, reading and writing concurrently is possible.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
# When type checking, import non-deprecated aliases eagerly. Else, import on demand.
 | 
					 | 
				
			||||||
if typing.TYPE_CHECKING:
 | 
					 | 
				
			||||||
    from .legacy.exceptions import (
 | 
					 | 
				
			||||||
        AbortHandshake,
 | 
					 | 
				
			||||||
        InvalidMessage,
 | 
					 | 
				
			||||||
        InvalidStatusCode,
 | 
					 | 
				
			||||||
        RedirectHandshake,
 | 
					 | 
				
			||||||
    )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    WebSocketProtocolError = ProtocolError
 | 
					 | 
				
			||||||
else:
 | 
					 | 
				
			||||||
    lazy_import(
 | 
					 | 
				
			||||||
        globals(),
 | 
					 | 
				
			||||||
        aliases={
 | 
					 | 
				
			||||||
            "AbortHandshake": ".legacy.exceptions",
 | 
					 | 
				
			||||||
            "InvalidMessage": ".legacy.exceptions",
 | 
					 | 
				
			||||||
            "InvalidStatusCode": ".legacy.exceptions",
 | 
					 | 
				
			||||||
            "RedirectHandshake": ".legacy.exceptions",
 | 
					 | 
				
			||||||
            "WebSocketProtocolError": ".legacy.exceptions",
 | 
					 | 
				
			||||||
        },
 | 
					 | 
				
			||||||
    )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
# At the bottom to break import cycles created by type annotations.
 | 
					 | 
				
			||||||
from . import frames, http11  # noqa: E402
 | 
					 | 
				
			||||||
@@ -1,4 +0,0 @@
 | 
				
			|||||||
from .base import *
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
__all__ = ["Extension", "ClientExtensionFactory", "ServerExtensionFactory"]
 | 
					 | 
				
			||||||
@@ -1,123 +0,0 @@
 | 
				
			|||||||
from __future__ import annotations
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
from typing import Sequence
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
from ..frames import Frame
 | 
					 | 
				
			||||||
from ..typing import ExtensionName, ExtensionParameter
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
__all__ = ["Extension", "ClientExtensionFactory", "ServerExtensionFactory"]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class Extension:
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Base class for extensions.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    name: ExtensionName
 | 
					 | 
				
			||||||
    """Extension identifier."""
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def decode(self, frame: Frame, *, max_size: int | None = None) -> Frame:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Decode an incoming frame.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Args:
 | 
					 | 
				
			||||||
            frame: Incoming frame.
 | 
					 | 
				
			||||||
            max_size: Maximum payload size in bytes.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Returns:
 | 
					 | 
				
			||||||
            Decoded frame.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Raises:
 | 
					 | 
				
			||||||
            PayloadTooBig: If decoding the payload exceeds ``max_size``.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        raise NotImplementedError
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def encode(self, frame: Frame) -> Frame:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Encode an outgoing frame.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Args:
 | 
					 | 
				
			||||||
            frame: Outgoing frame.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Returns:
 | 
					 | 
				
			||||||
            Encoded frame.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        raise NotImplementedError
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class ClientExtensionFactory:
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Base class for client-side extension factories.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    name: ExtensionName
 | 
					 | 
				
			||||||
    """Extension identifier."""
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def get_request_params(self) -> list[ExtensionParameter]:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Build parameters to send to the server for this extension.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Returns:
 | 
					 | 
				
			||||||
            Parameters to send to the server.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        raise NotImplementedError
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def process_response_params(
 | 
					 | 
				
			||||||
        self,
 | 
					 | 
				
			||||||
        params: Sequence[ExtensionParameter],
 | 
					 | 
				
			||||||
        accepted_extensions: Sequence[Extension],
 | 
					 | 
				
			||||||
    ) -> Extension:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Process parameters received from the server.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Args:
 | 
					 | 
				
			||||||
            params: Parameters received from the server for this extension.
 | 
					 | 
				
			||||||
            accepted_extensions: List of previously accepted extensions.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Returns:
 | 
					 | 
				
			||||||
            An extension instance.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Raises:
 | 
					 | 
				
			||||||
            NegotiationError: If parameters aren't acceptable.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        raise NotImplementedError
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class ServerExtensionFactory:
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Base class for server-side extension factories.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    name: ExtensionName
 | 
					 | 
				
			||||||
    """Extension identifier."""
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def process_request_params(
 | 
					 | 
				
			||||||
        self,
 | 
					 | 
				
			||||||
        params: Sequence[ExtensionParameter],
 | 
					 | 
				
			||||||
        accepted_extensions: Sequence[Extension],
 | 
					 | 
				
			||||||
    ) -> tuple[list[ExtensionParameter], Extension]:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Process parameters received from the client.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Args:
 | 
					 | 
				
			||||||
            params: Parameters received from the client for this extension.
 | 
					 | 
				
			||||||
            accepted_extensions: List of previously accepted extensions.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Returns:
 | 
					 | 
				
			||||||
            To accept the offer, parameters to send to the client for this
 | 
					 | 
				
			||||||
            extension and an extension instance.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Raises:
 | 
					 | 
				
			||||||
            NegotiationError: To reject the offer, if parameters received from
 | 
					 | 
				
			||||||
                the client aren't acceptable.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        raise NotImplementedError
 | 
					 | 
				
			||||||
@@ -1,670 +0,0 @@
 | 
				
			|||||||
from __future__ import annotations
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
import dataclasses
 | 
					 | 
				
			||||||
import zlib
 | 
					 | 
				
			||||||
from typing import Any, Sequence
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
from .. import frames
 | 
					 | 
				
			||||||
from ..exceptions import (
 | 
					 | 
				
			||||||
    DuplicateParameter,
 | 
					 | 
				
			||||||
    InvalidParameterName,
 | 
					 | 
				
			||||||
    InvalidParameterValue,
 | 
					 | 
				
			||||||
    NegotiationError,
 | 
					 | 
				
			||||||
    PayloadTooBig,
 | 
					 | 
				
			||||||
    ProtocolError,
 | 
					 | 
				
			||||||
)
 | 
					 | 
				
			||||||
from ..typing import ExtensionName, ExtensionParameter
 | 
					 | 
				
			||||||
from .base import ClientExtensionFactory, Extension, ServerExtensionFactory
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
__all__ = [
 | 
					 | 
				
			||||||
    "PerMessageDeflate",
 | 
					 | 
				
			||||||
    "ClientPerMessageDeflateFactory",
 | 
					 | 
				
			||||||
    "enable_client_permessage_deflate",
 | 
					 | 
				
			||||||
    "ServerPerMessageDeflateFactory",
 | 
					 | 
				
			||||||
    "enable_server_permessage_deflate",
 | 
					 | 
				
			||||||
]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
_EMPTY_UNCOMPRESSED_BLOCK = b"\x00\x00\xff\xff"
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
_MAX_WINDOW_BITS_VALUES = [str(bits) for bits in range(8, 16)]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class PerMessageDeflate(Extension):
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Per-Message Deflate extension.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    name = ExtensionName("permessage-deflate")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __init__(
 | 
					 | 
				
			||||||
        self,
 | 
					 | 
				
			||||||
        remote_no_context_takeover: bool,
 | 
					 | 
				
			||||||
        local_no_context_takeover: bool,
 | 
					 | 
				
			||||||
        remote_max_window_bits: int,
 | 
					 | 
				
			||||||
        local_max_window_bits: int,
 | 
					 | 
				
			||||||
        compress_settings: dict[Any, Any] | None = None,
 | 
					 | 
				
			||||||
    ) -> None:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Configure the Per-Message Deflate extension.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        if compress_settings is None:
 | 
					 | 
				
			||||||
            compress_settings = {}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        assert remote_no_context_takeover in [False, True]
 | 
					 | 
				
			||||||
        assert local_no_context_takeover in [False, True]
 | 
					 | 
				
			||||||
        assert 8 <= remote_max_window_bits <= 15
 | 
					 | 
				
			||||||
        assert 8 <= local_max_window_bits <= 15
 | 
					 | 
				
			||||||
        assert "wbits" not in compress_settings
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        self.remote_no_context_takeover = remote_no_context_takeover
 | 
					 | 
				
			||||||
        self.local_no_context_takeover = local_no_context_takeover
 | 
					 | 
				
			||||||
        self.remote_max_window_bits = remote_max_window_bits
 | 
					 | 
				
			||||||
        self.local_max_window_bits = local_max_window_bits
 | 
					 | 
				
			||||||
        self.compress_settings = compress_settings
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if not self.remote_no_context_takeover:
 | 
					 | 
				
			||||||
            self.decoder = zlib.decompressobj(wbits=-self.remote_max_window_bits)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if not self.local_no_context_takeover:
 | 
					 | 
				
			||||||
            self.encoder = zlib.compressobj(
 | 
					 | 
				
			||||||
                wbits=-self.local_max_window_bits,
 | 
					 | 
				
			||||||
                **self.compress_settings,
 | 
					 | 
				
			||||||
            )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # To handle continuation frames properly, we must keep track of
 | 
					 | 
				
			||||||
        # whether that initial frame was encoded.
 | 
					 | 
				
			||||||
        self.decode_cont_data = False
 | 
					 | 
				
			||||||
        # There's no need for self.encode_cont_data because we always encode
 | 
					 | 
				
			||||||
        # outgoing frames, so it would always be True.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __repr__(self) -> str:
 | 
					 | 
				
			||||||
        return (
 | 
					 | 
				
			||||||
            f"PerMessageDeflate("
 | 
					 | 
				
			||||||
            f"remote_no_context_takeover={self.remote_no_context_takeover}, "
 | 
					 | 
				
			||||||
            f"local_no_context_takeover={self.local_no_context_takeover}, "
 | 
					 | 
				
			||||||
            f"remote_max_window_bits={self.remote_max_window_bits}, "
 | 
					 | 
				
			||||||
            f"local_max_window_bits={self.local_max_window_bits})"
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def decode(
 | 
					 | 
				
			||||||
        self,
 | 
					 | 
				
			||||||
        frame: frames.Frame,
 | 
					 | 
				
			||||||
        *,
 | 
					 | 
				
			||||||
        max_size: int | None = None,
 | 
					 | 
				
			||||||
    ) -> frames.Frame:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Decode an incoming frame.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        # Skip control frames.
 | 
					 | 
				
			||||||
        if frame.opcode in frames.CTRL_OPCODES:
 | 
					 | 
				
			||||||
            return frame
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # Handle continuation data frames:
 | 
					 | 
				
			||||||
        # - skip if the message isn't encoded
 | 
					 | 
				
			||||||
        # - reset "decode continuation data" flag if it's a final frame
 | 
					 | 
				
			||||||
        if frame.opcode is frames.OP_CONT:
 | 
					 | 
				
			||||||
            if not self.decode_cont_data:
 | 
					 | 
				
			||||||
                return frame
 | 
					 | 
				
			||||||
            if frame.fin:
 | 
					 | 
				
			||||||
                self.decode_cont_data = False
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # Handle text and binary data frames:
 | 
					 | 
				
			||||||
        # - skip if the message isn't encoded
 | 
					 | 
				
			||||||
        # - unset the rsv1 flag on the first frame of a compressed message
 | 
					 | 
				
			||||||
        # - set "decode continuation data" flag if it's a non-final frame
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            if not frame.rsv1:
 | 
					 | 
				
			||||||
                return frame
 | 
					 | 
				
			||||||
            frame = dataclasses.replace(frame, rsv1=False)
 | 
					 | 
				
			||||||
            if not frame.fin:
 | 
					 | 
				
			||||||
                self.decode_cont_data = True
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            # Re-initialize per-message decoder.
 | 
					 | 
				
			||||||
            if self.remote_no_context_takeover:
 | 
					 | 
				
			||||||
                self.decoder = zlib.decompressobj(wbits=-self.remote_max_window_bits)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # Uncompress data. Protect against zip bombs by preventing zlib from
 | 
					 | 
				
			||||||
        # decompressing more than max_length bytes (except when the limit is
 | 
					 | 
				
			||||||
        # disabled with max_size = None).
 | 
					 | 
				
			||||||
        data = frame.data
 | 
					 | 
				
			||||||
        if frame.fin:
 | 
					 | 
				
			||||||
            data += _EMPTY_UNCOMPRESSED_BLOCK
 | 
					 | 
				
			||||||
        max_length = 0 if max_size is None else max_size
 | 
					 | 
				
			||||||
        try:
 | 
					 | 
				
			||||||
            data = self.decoder.decompress(data, max_length)
 | 
					 | 
				
			||||||
        except zlib.error as exc:
 | 
					 | 
				
			||||||
            raise ProtocolError("decompression failed") from exc
 | 
					 | 
				
			||||||
        if self.decoder.unconsumed_tail:
 | 
					 | 
				
			||||||
            raise PayloadTooBig(f"over size limit (? > {max_size} bytes)")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # Allow garbage collection of the decoder if it won't be reused.
 | 
					 | 
				
			||||||
        if frame.fin and self.remote_no_context_takeover:
 | 
					 | 
				
			||||||
            del self.decoder
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        return dataclasses.replace(frame, data=data)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def encode(self, frame: frames.Frame) -> frames.Frame:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Encode an outgoing frame.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        # Skip control frames.
 | 
					 | 
				
			||||||
        if frame.opcode in frames.CTRL_OPCODES:
 | 
					 | 
				
			||||||
            return frame
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # Since we always encode messages, there's no "encode continuation
 | 
					 | 
				
			||||||
        # data" flag similar to "decode continuation data" at this time.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if frame.opcode is not frames.OP_CONT:
 | 
					 | 
				
			||||||
            # Set the rsv1 flag on the first frame of a compressed message.
 | 
					 | 
				
			||||||
            frame = dataclasses.replace(frame, rsv1=True)
 | 
					 | 
				
			||||||
            # Re-initialize per-message decoder.
 | 
					 | 
				
			||||||
            if self.local_no_context_takeover:
 | 
					 | 
				
			||||||
                self.encoder = zlib.compressobj(
 | 
					 | 
				
			||||||
                    wbits=-self.local_max_window_bits,
 | 
					 | 
				
			||||||
                    **self.compress_settings,
 | 
					 | 
				
			||||||
                )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # Compress data.
 | 
					 | 
				
			||||||
        data = self.encoder.compress(frame.data) + self.encoder.flush(zlib.Z_SYNC_FLUSH)
 | 
					 | 
				
			||||||
        if frame.fin and data.endswith(_EMPTY_UNCOMPRESSED_BLOCK):
 | 
					 | 
				
			||||||
            data = data[:-4]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # Allow garbage collection of the encoder if it won't be reused.
 | 
					 | 
				
			||||||
        if frame.fin and self.local_no_context_takeover:
 | 
					 | 
				
			||||||
            del self.encoder
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        return dataclasses.replace(frame, data=data)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def _build_parameters(
 | 
					 | 
				
			||||||
    server_no_context_takeover: bool,
 | 
					 | 
				
			||||||
    client_no_context_takeover: bool,
 | 
					 | 
				
			||||||
    server_max_window_bits: int | None,
 | 
					 | 
				
			||||||
    client_max_window_bits: int | bool | None,
 | 
					 | 
				
			||||||
) -> list[ExtensionParameter]:
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Build a list of ``(name, value)`` pairs for some compression parameters.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    params: list[ExtensionParameter] = []
 | 
					 | 
				
			||||||
    if server_no_context_takeover:
 | 
					 | 
				
			||||||
        params.append(("server_no_context_takeover", None))
 | 
					 | 
				
			||||||
    if client_no_context_takeover:
 | 
					 | 
				
			||||||
        params.append(("client_no_context_takeover", None))
 | 
					 | 
				
			||||||
    if server_max_window_bits:
 | 
					 | 
				
			||||||
        params.append(("server_max_window_bits", str(server_max_window_bits)))
 | 
					 | 
				
			||||||
    if client_max_window_bits is True:  # only in handshake requests
 | 
					 | 
				
			||||||
        params.append(("client_max_window_bits", None))
 | 
					 | 
				
			||||||
    elif client_max_window_bits:
 | 
					 | 
				
			||||||
        params.append(("client_max_window_bits", str(client_max_window_bits)))
 | 
					 | 
				
			||||||
    return params
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def _extract_parameters(
 | 
					 | 
				
			||||||
    params: Sequence[ExtensionParameter], *, is_server: bool
 | 
					 | 
				
			||||||
) -> tuple[bool, bool, int | None, int | bool | None]:
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Extract compression parameters from a list of ``(name, value)`` pairs.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    If ``is_server`` is :obj:`True`, ``client_max_window_bits`` may be
 | 
					 | 
				
			||||||
    provided without a value. This is only allowed in handshake requests.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    server_no_context_takeover: bool = False
 | 
					 | 
				
			||||||
    client_no_context_takeover: bool = False
 | 
					 | 
				
			||||||
    server_max_window_bits: int | None = None
 | 
					 | 
				
			||||||
    client_max_window_bits: int | bool | None = None
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    for name, value in params:
 | 
					 | 
				
			||||||
        if name == "server_no_context_takeover":
 | 
					 | 
				
			||||||
            if server_no_context_takeover:
 | 
					 | 
				
			||||||
                raise DuplicateParameter(name)
 | 
					 | 
				
			||||||
            if value is None:
 | 
					 | 
				
			||||||
                server_no_context_takeover = True
 | 
					 | 
				
			||||||
            else:
 | 
					 | 
				
			||||||
                raise InvalidParameterValue(name, value)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        elif name == "client_no_context_takeover":
 | 
					 | 
				
			||||||
            if client_no_context_takeover:
 | 
					 | 
				
			||||||
                raise DuplicateParameter(name)
 | 
					 | 
				
			||||||
            if value is None:
 | 
					 | 
				
			||||||
                client_no_context_takeover = True
 | 
					 | 
				
			||||||
            else:
 | 
					 | 
				
			||||||
                raise InvalidParameterValue(name, value)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        elif name == "server_max_window_bits":
 | 
					 | 
				
			||||||
            if server_max_window_bits is not None:
 | 
					 | 
				
			||||||
                raise DuplicateParameter(name)
 | 
					 | 
				
			||||||
            if value in _MAX_WINDOW_BITS_VALUES:
 | 
					 | 
				
			||||||
                server_max_window_bits = int(value)
 | 
					 | 
				
			||||||
            else:
 | 
					 | 
				
			||||||
                raise InvalidParameterValue(name, value)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        elif name == "client_max_window_bits":
 | 
					 | 
				
			||||||
            if client_max_window_bits is not None:
 | 
					 | 
				
			||||||
                raise DuplicateParameter(name)
 | 
					 | 
				
			||||||
            if is_server and value is None:  # only in handshake requests
 | 
					 | 
				
			||||||
                client_max_window_bits = True
 | 
					 | 
				
			||||||
            elif value in _MAX_WINDOW_BITS_VALUES:
 | 
					 | 
				
			||||||
                client_max_window_bits = int(value)
 | 
					 | 
				
			||||||
            else:
 | 
					 | 
				
			||||||
                raise InvalidParameterValue(name, value)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            raise InvalidParameterName(name)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    return (
 | 
					 | 
				
			||||||
        server_no_context_takeover,
 | 
					 | 
				
			||||||
        client_no_context_takeover,
 | 
					 | 
				
			||||||
        server_max_window_bits,
 | 
					 | 
				
			||||||
        client_max_window_bits,
 | 
					 | 
				
			||||||
    )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class ClientPerMessageDeflateFactory(ClientExtensionFactory):
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Client-side extension factory for the Per-Message Deflate extension.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Parameters behave as described in `section 7.1 of RFC 7692`_.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    .. _section 7.1 of RFC 7692: https://datatracker.ietf.org/doc/html/rfc7692#section-7.1
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Set them to :obj:`True` to include them in the negotiation offer without a
 | 
					 | 
				
			||||||
    value or to an integer value to include them with this value.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Args:
 | 
					 | 
				
			||||||
        server_no_context_takeover: Prevent server from using context takeover.
 | 
					 | 
				
			||||||
        client_no_context_takeover: Prevent client from using context takeover.
 | 
					 | 
				
			||||||
        server_max_window_bits: Maximum size of the server's LZ77 sliding window
 | 
					 | 
				
			||||||
            in bits, between 8 and 15.
 | 
					 | 
				
			||||||
        client_max_window_bits: Maximum size of the client's LZ77 sliding window
 | 
					 | 
				
			||||||
            in bits, between 8 and 15, or :obj:`True` to indicate support without
 | 
					 | 
				
			||||||
            setting a limit.
 | 
					 | 
				
			||||||
        compress_settings: Additional keyword arguments for :func:`zlib.compressobj`,
 | 
					 | 
				
			||||||
            excluding ``wbits``.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    name = ExtensionName("permessage-deflate")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __init__(
 | 
					 | 
				
			||||||
        self,
 | 
					 | 
				
			||||||
        server_no_context_takeover: bool = False,
 | 
					 | 
				
			||||||
        client_no_context_takeover: bool = False,
 | 
					 | 
				
			||||||
        server_max_window_bits: int | None = None,
 | 
					 | 
				
			||||||
        client_max_window_bits: int | bool | None = True,
 | 
					 | 
				
			||||||
        compress_settings: dict[str, Any] | None = None,
 | 
					 | 
				
			||||||
    ) -> None:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Configure the Per-Message Deflate extension factory.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        if not (server_max_window_bits is None or 8 <= server_max_window_bits <= 15):
 | 
					 | 
				
			||||||
            raise ValueError("server_max_window_bits must be between 8 and 15")
 | 
					 | 
				
			||||||
        if not (
 | 
					 | 
				
			||||||
            client_max_window_bits is None
 | 
					 | 
				
			||||||
            or client_max_window_bits is True
 | 
					 | 
				
			||||||
            or 8 <= client_max_window_bits <= 15
 | 
					 | 
				
			||||||
        ):
 | 
					 | 
				
			||||||
            raise ValueError("client_max_window_bits must be between 8 and 15")
 | 
					 | 
				
			||||||
        if compress_settings is not None and "wbits" in compress_settings:
 | 
					 | 
				
			||||||
            raise ValueError(
 | 
					 | 
				
			||||||
                "compress_settings must not include wbits, "
 | 
					 | 
				
			||||||
                "set client_max_window_bits instead"
 | 
					 | 
				
			||||||
            )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        self.server_no_context_takeover = server_no_context_takeover
 | 
					 | 
				
			||||||
        self.client_no_context_takeover = client_no_context_takeover
 | 
					 | 
				
			||||||
        self.server_max_window_bits = server_max_window_bits
 | 
					 | 
				
			||||||
        self.client_max_window_bits = client_max_window_bits
 | 
					 | 
				
			||||||
        self.compress_settings = compress_settings
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def get_request_params(self) -> list[ExtensionParameter]:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Build request parameters.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        return _build_parameters(
 | 
					 | 
				
			||||||
            self.server_no_context_takeover,
 | 
					 | 
				
			||||||
            self.client_no_context_takeover,
 | 
					 | 
				
			||||||
            self.server_max_window_bits,
 | 
					 | 
				
			||||||
            self.client_max_window_bits,
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def process_response_params(
 | 
					 | 
				
			||||||
        self,
 | 
					 | 
				
			||||||
        params: Sequence[ExtensionParameter],
 | 
					 | 
				
			||||||
        accepted_extensions: Sequence[Extension],
 | 
					 | 
				
			||||||
    ) -> PerMessageDeflate:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Process response parameters.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Return an extension instance.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        if any(other.name == self.name for other in accepted_extensions):
 | 
					 | 
				
			||||||
            raise NegotiationError(f"received duplicate {self.name}")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # Request parameters are available in instance variables.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # Load response parameters in local variables.
 | 
					 | 
				
			||||||
        (
 | 
					 | 
				
			||||||
            server_no_context_takeover,
 | 
					 | 
				
			||||||
            client_no_context_takeover,
 | 
					 | 
				
			||||||
            server_max_window_bits,
 | 
					 | 
				
			||||||
            client_max_window_bits,
 | 
					 | 
				
			||||||
        ) = _extract_parameters(params, is_server=False)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # After comparing the request and the response, the final
 | 
					 | 
				
			||||||
        # configuration must be available in the local variables.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # server_no_context_takeover
 | 
					 | 
				
			||||||
        #
 | 
					 | 
				
			||||||
        #   Req.    Resp.   Result
 | 
					 | 
				
			||||||
        #   ------  ------  --------------------------------------------------
 | 
					 | 
				
			||||||
        #   False   False   False
 | 
					 | 
				
			||||||
        #   False   True    True
 | 
					 | 
				
			||||||
        #   True    False   Error!
 | 
					 | 
				
			||||||
        #   True    True    True
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if self.server_no_context_takeover:
 | 
					 | 
				
			||||||
            if not server_no_context_takeover:
 | 
					 | 
				
			||||||
                raise NegotiationError("expected server_no_context_takeover")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # client_no_context_takeover
 | 
					 | 
				
			||||||
        #
 | 
					 | 
				
			||||||
        #   Req.    Resp.   Result
 | 
					 | 
				
			||||||
        #   ------  ------  --------------------------------------------------
 | 
					 | 
				
			||||||
        #   False   False   False
 | 
					 | 
				
			||||||
        #   False   True    True
 | 
					 | 
				
			||||||
        #   True    False   True - must change value
 | 
					 | 
				
			||||||
        #   True    True    True
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if self.client_no_context_takeover:
 | 
					 | 
				
			||||||
            if not client_no_context_takeover:
 | 
					 | 
				
			||||||
                client_no_context_takeover = True
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # server_max_window_bits
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        #   Req.    Resp.   Result
 | 
					 | 
				
			||||||
        #   ------  ------  --------------------------------------------------
 | 
					 | 
				
			||||||
        #   None    None    None
 | 
					 | 
				
			||||||
        #   None    8≤M≤15  M
 | 
					 | 
				
			||||||
        #   8≤N≤15  None    Error!
 | 
					 | 
				
			||||||
        #   8≤N≤15  8≤M≤N   M
 | 
					 | 
				
			||||||
        #   8≤N≤15  N<M≤15  Error!
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if self.server_max_window_bits is None:
 | 
					 | 
				
			||||||
            pass
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            if server_max_window_bits is None:
 | 
					 | 
				
			||||||
                raise NegotiationError("expected server_max_window_bits")
 | 
					 | 
				
			||||||
            elif server_max_window_bits > self.server_max_window_bits:
 | 
					 | 
				
			||||||
                raise NegotiationError("unsupported server_max_window_bits")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # client_max_window_bits
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        #   Req.    Resp.   Result
 | 
					 | 
				
			||||||
        #   ------  ------  --------------------------------------------------
 | 
					 | 
				
			||||||
        #   None    None    None
 | 
					 | 
				
			||||||
        #   None    8≤M≤15  Error!
 | 
					 | 
				
			||||||
        #   True    None    None
 | 
					 | 
				
			||||||
        #   True    8≤M≤15  M
 | 
					 | 
				
			||||||
        #   8≤N≤15  None    N - must change value
 | 
					 | 
				
			||||||
        #   8≤N≤15  8≤M≤N   M
 | 
					 | 
				
			||||||
        #   8≤N≤15  N<M≤15  Error!
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if self.client_max_window_bits is None:
 | 
					 | 
				
			||||||
            if client_max_window_bits is not None:
 | 
					 | 
				
			||||||
                raise NegotiationError("unexpected client_max_window_bits")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        elif self.client_max_window_bits is True:
 | 
					 | 
				
			||||||
            pass
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            if client_max_window_bits is None:
 | 
					 | 
				
			||||||
                client_max_window_bits = self.client_max_window_bits
 | 
					 | 
				
			||||||
            elif client_max_window_bits > self.client_max_window_bits:
 | 
					 | 
				
			||||||
                raise NegotiationError("unsupported client_max_window_bits")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        return PerMessageDeflate(
 | 
					 | 
				
			||||||
            server_no_context_takeover,  # remote_no_context_takeover
 | 
					 | 
				
			||||||
            client_no_context_takeover,  # local_no_context_takeover
 | 
					 | 
				
			||||||
            server_max_window_bits or 15,  # remote_max_window_bits
 | 
					 | 
				
			||||||
            client_max_window_bits or 15,  # local_max_window_bits
 | 
					 | 
				
			||||||
            self.compress_settings,
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def enable_client_permessage_deflate(
 | 
					 | 
				
			||||||
    extensions: Sequence[ClientExtensionFactory] | None,
 | 
					 | 
				
			||||||
) -> Sequence[ClientExtensionFactory]:
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Enable Per-Message Deflate with default settings in client extensions.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    If the extension is already present, perhaps with non-default settings,
 | 
					 | 
				
			||||||
    the configuration isn't changed.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    if extensions is None:
 | 
					 | 
				
			||||||
        extensions = []
 | 
					 | 
				
			||||||
    if not any(
 | 
					 | 
				
			||||||
        extension_factory.name == ClientPerMessageDeflateFactory.name
 | 
					 | 
				
			||||||
        for extension_factory in extensions
 | 
					 | 
				
			||||||
    ):
 | 
					 | 
				
			||||||
        extensions = list(extensions) + [
 | 
					 | 
				
			||||||
            ClientPerMessageDeflateFactory(
 | 
					 | 
				
			||||||
                compress_settings={"memLevel": 5},
 | 
					 | 
				
			||||||
            )
 | 
					 | 
				
			||||||
        ]
 | 
					 | 
				
			||||||
    return extensions
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class ServerPerMessageDeflateFactory(ServerExtensionFactory):
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Server-side extension factory for the Per-Message Deflate extension.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Parameters behave as described in `section 7.1 of RFC 7692`_.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    .. _section 7.1 of RFC 7692: https://datatracker.ietf.org/doc/html/rfc7692#section-7.1
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Set them to :obj:`True` to include them in the negotiation offer without a
 | 
					 | 
				
			||||||
    value or to an integer value to include them with this value.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Args:
 | 
					 | 
				
			||||||
        server_no_context_takeover: Prevent server from using context takeover.
 | 
					 | 
				
			||||||
        client_no_context_takeover: Prevent client from using context takeover.
 | 
					 | 
				
			||||||
        server_max_window_bits: Maximum size of the server's LZ77 sliding window
 | 
					 | 
				
			||||||
            in bits, between 8 and 15.
 | 
					 | 
				
			||||||
        client_max_window_bits: Maximum size of the client's LZ77 sliding window
 | 
					 | 
				
			||||||
            in bits, between 8 and 15.
 | 
					 | 
				
			||||||
        compress_settings: Additional keyword arguments for :func:`zlib.compressobj`,
 | 
					 | 
				
			||||||
            excluding ``wbits``.
 | 
					 | 
				
			||||||
        require_client_max_window_bits: Do not enable compression at all if
 | 
					 | 
				
			||||||
            client doesn't advertise support for ``client_max_window_bits``;
 | 
					 | 
				
			||||||
            the default behavior is to enable compression without enforcing
 | 
					 | 
				
			||||||
            ``client_max_window_bits``.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    name = ExtensionName("permessage-deflate")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __init__(
 | 
					 | 
				
			||||||
        self,
 | 
					 | 
				
			||||||
        server_no_context_takeover: bool = False,
 | 
					 | 
				
			||||||
        client_no_context_takeover: bool = False,
 | 
					 | 
				
			||||||
        server_max_window_bits: int | None = None,
 | 
					 | 
				
			||||||
        client_max_window_bits: int | None = None,
 | 
					 | 
				
			||||||
        compress_settings: dict[str, Any] | None = None,
 | 
					 | 
				
			||||||
        require_client_max_window_bits: bool = False,
 | 
					 | 
				
			||||||
    ) -> None:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Configure the Per-Message Deflate extension factory.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        if not (server_max_window_bits is None or 8 <= server_max_window_bits <= 15):
 | 
					 | 
				
			||||||
            raise ValueError("server_max_window_bits must be between 8 and 15")
 | 
					 | 
				
			||||||
        if not (client_max_window_bits is None or 8 <= client_max_window_bits <= 15):
 | 
					 | 
				
			||||||
            raise ValueError("client_max_window_bits must be between 8 and 15")
 | 
					 | 
				
			||||||
        if compress_settings is not None and "wbits" in compress_settings:
 | 
					 | 
				
			||||||
            raise ValueError(
 | 
					 | 
				
			||||||
                "compress_settings must not include wbits, "
 | 
					 | 
				
			||||||
                "set server_max_window_bits instead"
 | 
					 | 
				
			||||||
            )
 | 
					 | 
				
			||||||
        if client_max_window_bits is None and require_client_max_window_bits:
 | 
					 | 
				
			||||||
            raise ValueError(
 | 
					 | 
				
			||||||
                "require_client_max_window_bits is enabled, "
 | 
					 | 
				
			||||||
                "but client_max_window_bits isn't configured"
 | 
					 | 
				
			||||||
            )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        self.server_no_context_takeover = server_no_context_takeover
 | 
					 | 
				
			||||||
        self.client_no_context_takeover = client_no_context_takeover
 | 
					 | 
				
			||||||
        self.server_max_window_bits = server_max_window_bits
 | 
					 | 
				
			||||||
        self.client_max_window_bits = client_max_window_bits
 | 
					 | 
				
			||||||
        self.compress_settings = compress_settings
 | 
					 | 
				
			||||||
        self.require_client_max_window_bits = require_client_max_window_bits
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def process_request_params(
 | 
					 | 
				
			||||||
        self,
 | 
					 | 
				
			||||||
        params: Sequence[ExtensionParameter],
 | 
					 | 
				
			||||||
        accepted_extensions: Sequence[Extension],
 | 
					 | 
				
			||||||
    ) -> tuple[list[ExtensionParameter], PerMessageDeflate]:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Process request parameters.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Return response params and an extension instance.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        if any(other.name == self.name for other in accepted_extensions):
 | 
					 | 
				
			||||||
            raise NegotiationError(f"skipped duplicate {self.name}")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # Load request parameters in local variables.
 | 
					 | 
				
			||||||
        (
 | 
					 | 
				
			||||||
            server_no_context_takeover,
 | 
					 | 
				
			||||||
            client_no_context_takeover,
 | 
					 | 
				
			||||||
            server_max_window_bits,
 | 
					 | 
				
			||||||
            client_max_window_bits,
 | 
					 | 
				
			||||||
        ) = _extract_parameters(params, is_server=True)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # Configuration parameters are available in instance variables.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # After comparing the request and the configuration, the response must
 | 
					 | 
				
			||||||
        # be available in the local variables.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # server_no_context_takeover
 | 
					 | 
				
			||||||
        #
 | 
					 | 
				
			||||||
        #   Config  Req.    Resp.
 | 
					 | 
				
			||||||
        #   ------  ------  --------------------------------------------------
 | 
					 | 
				
			||||||
        #   False   False   False
 | 
					 | 
				
			||||||
        #   False   True    True
 | 
					 | 
				
			||||||
        #   True    False   True - must change value to True
 | 
					 | 
				
			||||||
        #   True    True    True
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if self.server_no_context_takeover:
 | 
					 | 
				
			||||||
            if not server_no_context_takeover:
 | 
					 | 
				
			||||||
                server_no_context_takeover = True
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # client_no_context_takeover
 | 
					 | 
				
			||||||
        #
 | 
					 | 
				
			||||||
        #   Config  Req.    Resp.
 | 
					 | 
				
			||||||
        #   ------  ------  --------------------------------------------------
 | 
					 | 
				
			||||||
        #   False   False   False
 | 
					 | 
				
			||||||
        #   False   True    True (or False)
 | 
					 | 
				
			||||||
        #   True    False   True - must change value to True
 | 
					 | 
				
			||||||
        #   True    True    True (or False)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if self.client_no_context_takeover:
 | 
					 | 
				
			||||||
            if not client_no_context_takeover:
 | 
					 | 
				
			||||||
                client_no_context_takeover = True
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # server_max_window_bits
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        #   Config  Req.    Resp.
 | 
					 | 
				
			||||||
        #   ------  ------  --------------------------------------------------
 | 
					 | 
				
			||||||
        #   None    None    None
 | 
					 | 
				
			||||||
        #   None    8≤M≤15  M
 | 
					 | 
				
			||||||
        #   8≤N≤15  None    N - must change value
 | 
					 | 
				
			||||||
        #   8≤N≤15  8≤M≤N   M
 | 
					 | 
				
			||||||
        #   8≤N≤15  N<M≤15  N - must change value
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if self.server_max_window_bits is None:
 | 
					 | 
				
			||||||
            pass
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            if server_max_window_bits is None:
 | 
					 | 
				
			||||||
                server_max_window_bits = self.server_max_window_bits
 | 
					 | 
				
			||||||
            elif server_max_window_bits > self.server_max_window_bits:
 | 
					 | 
				
			||||||
                server_max_window_bits = self.server_max_window_bits
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # client_max_window_bits
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        #   Config  Req.    Resp.
 | 
					 | 
				
			||||||
        #   ------  ------  --------------------------------------------------
 | 
					 | 
				
			||||||
        #   None    None    None
 | 
					 | 
				
			||||||
        #   None    True    None - must change value
 | 
					 | 
				
			||||||
        #   None    8≤M≤15  M (or None)
 | 
					 | 
				
			||||||
        #   8≤N≤15  None    None or Error!
 | 
					 | 
				
			||||||
        #   8≤N≤15  True    N - must change value
 | 
					 | 
				
			||||||
        #   8≤N≤15  8≤M≤N   M (or None)
 | 
					 | 
				
			||||||
        #   8≤N≤15  N<M≤15  N
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if self.client_max_window_bits is None:
 | 
					 | 
				
			||||||
            if client_max_window_bits is True:
 | 
					 | 
				
			||||||
                client_max_window_bits = self.client_max_window_bits
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            if client_max_window_bits is None:
 | 
					 | 
				
			||||||
                if self.require_client_max_window_bits:
 | 
					 | 
				
			||||||
                    raise NegotiationError("required client_max_window_bits")
 | 
					 | 
				
			||||||
            elif client_max_window_bits is True:
 | 
					 | 
				
			||||||
                client_max_window_bits = self.client_max_window_bits
 | 
					 | 
				
			||||||
            elif self.client_max_window_bits < client_max_window_bits:
 | 
					 | 
				
			||||||
                client_max_window_bits = self.client_max_window_bits
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        return (
 | 
					 | 
				
			||||||
            _build_parameters(
 | 
					 | 
				
			||||||
                server_no_context_takeover,
 | 
					 | 
				
			||||||
                client_no_context_takeover,
 | 
					 | 
				
			||||||
                server_max_window_bits,
 | 
					 | 
				
			||||||
                client_max_window_bits,
 | 
					 | 
				
			||||||
            ),
 | 
					 | 
				
			||||||
            PerMessageDeflate(
 | 
					 | 
				
			||||||
                client_no_context_takeover,  # remote_no_context_takeover
 | 
					 | 
				
			||||||
                server_no_context_takeover,  # local_no_context_takeover
 | 
					 | 
				
			||||||
                client_max_window_bits or 15,  # remote_max_window_bits
 | 
					 | 
				
			||||||
                server_max_window_bits or 15,  # local_max_window_bits
 | 
					 | 
				
			||||||
                self.compress_settings,
 | 
					 | 
				
			||||||
            ),
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def enable_server_permessage_deflate(
 | 
					 | 
				
			||||||
    extensions: Sequence[ServerExtensionFactory] | None,
 | 
					 | 
				
			||||||
) -> Sequence[ServerExtensionFactory]:
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Enable Per-Message Deflate with default settings in server extensions.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    If the extension is already present, perhaps with non-default settings,
 | 
					 | 
				
			||||||
    the configuration isn't changed.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    if extensions is None:
 | 
					 | 
				
			||||||
        extensions = []
 | 
					 | 
				
			||||||
    if not any(
 | 
					 | 
				
			||||||
        ext_factory.name == ServerPerMessageDeflateFactory.name
 | 
					 | 
				
			||||||
        for ext_factory in extensions
 | 
					 | 
				
			||||||
    ):
 | 
					 | 
				
			||||||
        extensions = list(extensions) + [
 | 
					 | 
				
			||||||
            ServerPerMessageDeflateFactory(
 | 
					 | 
				
			||||||
                server_max_window_bits=12,
 | 
					 | 
				
			||||||
                client_max_window_bits=12,
 | 
					 | 
				
			||||||
                compress_settings={"memLevel": 5},
 | 
					 | 
				
			||||||
            )
 | 
					 | 
				
			||||||
        ]
 | 
					 | 
				
			||||||
    return extensions
 | 
					 | 
				
			||||||
@@ -1,429 +0,0 @@
 | 
				
			|||||||
from __future__ import annotations
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
import dataclasses
 | 
					 | 
				
			||||||
import enum
 | 
					 | 
				
			||||||
import io
 | 
					 | 
				
			||||||
import os
 | 
					 | 
				
			||||||
import secrets
 | 
					 | 
				
			||||||
import struct
 | 
					 | 
				
			||||||
from typing import Callable, Generator, Sequence
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
from .exceptions import PayloadTooBig, ProtocolError
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
try:
 | 
					 | 
				
			||||||
    from .speedups import apply_mask
 | 
					 | 
				
			||||||
except ImportError:
 | 
					 | 
				
			||||||
    from .utils import apply_mask
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
__all__ = [
 | 
					 | 
				
			||||||
    "Opcode",
 | 
					 | 
				
			||||||
    "OP_CONT",
 | 
					 | 
				
			||||||
    "OP_TEXT",
 | 
					 | 
				
			||||||
    "OP_BINARY",
 | 
					 | 
				
			||||||
    "OP_CLOSE",
 | 
					 | 
				
			||||||
    "OP_PING",
 | 
					 | 
				
			||||||
    "OP_PONG",
 | 
					 | 
				
			||||||
    "DATA_OPCODES",
 | 
					 | 
				
			||||||
    "CTRL_OPCODES",
 | 
					 | 
				
			||||||
    "Frame",
 | 
					 | 
				
			||||||
    "Close",
 | 
					 | 
				
			||||||
]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class Opcode(enum.IntEnum):
 | 
					 | 
				
			||||||
    """Opcode values for WebSocket frames."""
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    CONT, TEXT, BINARY = 0x00, 0x01, 0x02
 | 
					 | 
				
			||||||
    CLOSE, PING, PONG = 0x08, 0x09, 0x0A
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
OP_CONT = Opcode.CONT
 | 
					 | 
				
			||||||
OP_TEXT = Opcode.TEXT
 | 
					 | 
				
			||||||
OP_BINARY = Opcode.BINARY
 | 
					 | 
				
			||||||
OP_CLOSE = Opcode.CLOSE
 | 
					 | 
				
			||||||
OP_PING = Opcode.PING
 | 
					 | 
				
			||||||
OP_PONG = Opcode.PONG
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
DATA_OPCODES = OP_CONT, OP_TEXT, OP_BINARY
 | 
					 | 
				
			||||||
CTRL_OPCODES = OP_CLOSE, OP_PING, OP_PONG
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class CloseCode(enum.IntEnum):
 | 
					 | 
				
			||||||
    """Close code values for WebSocket close frames."""
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    NORMAL_CLOSURE = 1000
 | 
					 | 
				
			||||||
    GOING_AWAY = 1001
 | 
					 | 
				
			||||||
    PROTOCOL_ERROR = 1002
 | 
					 | 
				
			||||||
    UNSUPPORTED_DATA = 1003
 | 
					 | 
				
			||||||
    # 1004 is reserved
 | 
					 | 
				
			||||||
    NO_STATUS_RCVD = 1005
 | 
					 | 
				
			||||||
    ABNORMAL_CLOSURE = 1006
 | 
					 | 
				
			||||||
    INVALID_DATA = 1007
 | 
					 | 
				
			||||||
    POLICY_VIOLATION = 1008
 | 
					 | 
				
			||||||
    MESSAGE_TOO_BIG = 1009
 | 
					 | 
				
			||||||
    MANDATORY_EXTENSION = 1010
 | 
					 | 
				
			||||||
    INTERNAL_ERROR = 1011
 | 
					 | 
				
			||||||
    SERVICE_RESTART = 1012
 | 
					 | 
				
			||||||
    TRY_AGAIN_LATER = 1013
 | 
					 | 
				
			||||||
    BAD_GATEWAY = 1014
 | 
					 | 
				
			||||||
    TLS_HANDSHAKE = 1015
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
# See https://www.iana.org/assignments/websocket/websocket.xhtml
 | 
					 | 
				
			||||||
CLOSE_CODE_EXPLANATIONS: dict[int, str] = {
 | 
					 | 
				
			||||||
    CloseCode.NORMAL_CLOSURE: "OK",
 | 
					 | 
				
			||||||
    CloseCode.GOING_AWAY: "going away",
 | 
					 | 
				
			||||||
    CloseCode.PROTOCOL_ERROR: "protocol error",
 | 
					 | 
				
			||||||
    CloseCode.UNSUPPORTED_DATA: "unsupported data",
 | 
					 | 
				
			||||||
    CloseCode.NO_STATUS_RCVD: "no status received [internal]",
 | 
					 | 
				
			||||||
    CloseCode.ABNORMAL_CLOSURE: "abnormal closure [internal]",
 | 
					 | 
				
			||||||
    CloseCode.INVALID_DATA: "invalid frame payload data",
 | 
					 | 
				
			||||||
    CloseCode.POLICY_VIOLATION: "policy violation",
 | 
					 | 
				
			||||||
    CloseCode.MESSAGE_TOO_BIG: "message too big",
 | 
					 | 
				
			||||||
    CloseCode.MANDATORY_EXTENSION: "mandatory extension",
 | 
					 | 
				
			||||||
    CloseCode.INTERNAL_ERROR: "internal error",
 | 
					 | 
				
			||||||
    CloseCode.SERVICE_RESTART: "service restart",
 | 
					 | 
				
			||||||
    CloseCode.TRY_AGAIN_LATER: "try again later",
 | 
					 | 
				
			||||||
    CloseCode.BAD_GATEWAY: "bad gateway",
 | 
					 | 
				
			||||||
    CloseCode.TLS_HANDSHAKE: "TLS handshake failure [internal]",
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
# Close code that are allowed in a close frame.
 | 
					 | 
				
			||||||
# Using a set optimizes `code in EXTERNAL_CLOSE_CODES`.
 | 
					 | 
				
			||||||
EXTERNAL_CLOSE_CODES = {
 | 
					 | 
				
			||||||
    CloseCode.NORMAL_CLOSURE,
 | 
					 | 
				
			||||||
    CloseCode.GOING_AWAY,
 | 
					 | 
				
			||||||
    CloseCode.PROTOCOL_ERROR,
 | 
					 | 
				
			||||||
    CloseCode.UNSUPPORTED_DATA,
 | 
					 | 
				
			||||||
    CloseCode.INVALID_DATA,
 | 
					 | 
				
			||||||
    CloseCode.POLICY_VIOLATION,
 | 
					 | 
				
			||||||
    CloseCode.MESSAGE_TOO_BIG,
 | 
					 | 
				
			||||||
    CloseCode.MANDATORY_EXTENSION,
 | 
					 | 
				
			||||||
    CloseCode.INTERNAL_ERROR,
 | 
					 | 
				
			||||||
    CloseCode.SERVICE_RESTART,
 | 
					 | 
				
			||||||
    CloseCode.TRY_AGAIN_LATER,
 | 
					 | 
				
			||||||
    CloseCode.BAD_GATEWAY,
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
OK_CLOSE_CODES = {
 | 
					 | 
				
			||||||
    CloseCode.NORMAL_CLOSURE,
 | 
					 | 
				
			||||||
    CloseCode.GOING_AWAY,
 | 
					 | 
				
			||||||
    CloseCode.NO_STATUS_RCVD,
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
BytesLike = bytes, bytearray, memoryview
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
@dataclasses.dataclass
 | 
					 | 
				
			||||||
class Frame:
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    WebSocket frame.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Attributes:
 | 
					 | 
				
			||||||
        opcode: Opcode.
 | 
					 | 
				
			||||||
        data: Payload data.
 | 
					 | 
				
			||||||
        fin: FIN bit.
 | 
					 | 
				
			||||||
        rsv1: RSV1 bit.
 | 
					 | 
				
			||||||
        rsv2: RSV2 bit.
 | 
					 | 
				
			||||||
        rsv3: RSV3 bit.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Only these fields are needed. The MASK bit, payload length and masking-key
 | 
					 | 
				
			||||||
    are handled on the fly when parsing and serializing frames.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    opcode: Opcode
 | 
					 | 
				
			||||||
    data: bytes
 | 
					 | 
				
			||||||
    fin: bool = True
 | 
					 | 
				
			||||||
    rsv1: bool = False
 | 
					 | 
				
			||||||
    rsv2: bool = False
 | 
					 | 
				
			||||||
    rsv3: bool = False
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    # Configure if you want to see more in logs. Should be a multiple of 3.
 | 
					 | 
				
			||||||
    MAX_LOG_SIZE = int(os.environ.get("WEBSOCKETS_MAX_LOG_SIZE", "75"))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __str__(self) -> str:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Return a human-readable representation of a frame.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        coding = None
 | 
					 | 
				
			||||||
        length = f"{len(self.data)} byte{'' if len(self.data) == 1 else 's'}"
 | 
					 | 
				
			||||||
        non_final = "" if self.fin else "continued"
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if self.opcode is OP_TEXT:
 | 
					 | 
				
			||||||
            # Decoding only the beginning and the end is needlessly hard.
 | 
					 | 
				
			||||||
            # Decode the entire payload then elide later if necessary.
 | 
					 | 
				
			||||||
            data = repr(self.data.decode())
 | 
					 | 
				
			||||||
        elif self.opcode is OP_BINARY:
 | 
					 | 
				
			||||||
            # We'll show at most the first 16 bytes and the last 8 bytes.
 | 
					 | 
				
			||||||
            # Encode just what we need, plus two dummy bytes to elide later.
 | 
					 | 
				
			||||||
            binary = self.data
 | 
					 | 
				
			||||||
            if len(binary) > self.MAX_LOG_SIZE // 3:
 | 
					 | 
				
			||||||
                cut = (self.MAX_LOG_SIZE // 3 - 1) // 3  # by default cut = 8
 | 
					 | 
				
			||||||
                binary = b"".join([binary[: 2 * cut], b"\x00\x00", binary[-cut:]])
 | 
					 | 
				
			||||||
            data = " ".join(f"{byte:02x}" for byte in binary)
 | 
					 | 
				
			||||||
        elif self.opcode is OP_CLOSE:
 | 
					 | 
				
			||||||
            data = str(Close.parse(self.data))
 | 
					 | 
				
			||||||
        elif self.data:
 | 
					 | 
				
			||||||
            # We don't know if a Continuation frame contains text or binary.
 | 
					 | 
				
			||||||
            # Ping and Pong frames could contain UTF-8.
 | 
					 | 
				
			||||||
            # Attempt to decode as UTF-8 and display it as text; fallback to
 | 
					 | 
				
			||||||
            # binary. If self.data is a memoryview, it has no decode() method,
 | 
					 | 
				
			||||||
            # which raises AttributeError.
 | 
					 | 
				
			||||||
            try:
 | 
					 | 
				
			||||||
                data = repr(self.data.decode())
 | 
					 | 
				
			||||||
                coding = "text"
 | 
					 | 
				
			||||||
            except (UnicodeDecodeError, AttributeError):
 | 
					 | 
				
			||||||
                binary = self.data
 | 
					 | 
				
			||||||
                if len(binary) > self.MAX_LOG_SIZE // 3:
 | 
					 | 
				
			||||||
                    cut = (self.MAX_LOG_SIZE // 3 - 1) // 3  # by default cut = 8
 | 
					 | 
				
			||||||
                    binary = b"".join([binary[: 2 * cut], b"\x00\x00", binary[-cut:]])
 | 
					 | 
				
			||||||
                data = " ".join(f"{byte:02x}" for byte in binary)
 | 
					 | 
				
			||||||
                coding = "binary"
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            data = "''"
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if len(data) > self.MAX_LOG_SIZE:
 | 
					 | 
				
			||||||
            cut = self.MAX_LOG_SIZE // 3 - 1  # by default cut = 24
 | 
					 | 
				
			||||||
            data = data[: 2 * cut] + "..." + data[-cut:]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        metadata = ", ".join(filter(None, [coding, length, non_final]))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        return f"{self.opcode.name} {data} [{metadata}]"
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    @classmethod
 | 
					 | 
				
			||||||
    def parse(
 | 
					 | 
				
			||||||
        cls,
 | 
					 | 
				
			||||||
        read_exact: Callable[[int], Generator[None, None, bytes]],
 | 
					 | 
				
			||||||
        *,
 | 
					 | 
				
			||||||
        mask: bool,
 | 
					 | 
				
			||||||
        max_size: int | None = None,
 | 
					 | 
				
			||||||
        extensions: Sequence[extensions.Extension] | None = None,
 | 
					 | 
				
			||||||
    ) -> Generator[None, None, Frame]:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Parse a WebSocket frame.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        This is a generator-based coroutine.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Args:
 | 
					 | 
				
			||||||
            read_exact: Generator-based coroutine that reads the requested
 | 
					 | 
				
			||||||
                bytes or raises an exception if there isn't enough data.
 | 
					 | 
				
			||||||
            mask: Whether the frame should be masked i.e. whether the read
 | 
					 | 
				
			||||||
                happens on the server side.
 | 
					 | 
				
			||||||
            max_size: Maximum payload size in bytes.
 | 
					 | 
				
			||||||
            extensions: List of extensions, applied in reverse order.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Raises:
 | 
					 | 
				
			||||||
            EOFError: If the connection is closed without a full WebSocket frame.
 | 
					 | 
				
			||||||
            UnicodeDecodeError: If the frame contains invalid UTF-8.
 | 
					 | 
				
			||||||
            PayloadTooBig: If the frame's payload size exceeds ``max_size``.
 | 
					 | 
				
			||||||
            ProtocolError: If the frame contains incorrect values.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        # Read the header.
 | 
					 | 
				
			||||||
        data = yield from read_exact(2)
 | 
					 | 
				
			||||||
        head1, head2 = struct.unpack("!BB", data)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # While not Pythonic, this is marginally faster than calling bool().
 | 
					 | 
				
			||||||
        fin = True if head1 & 0b10000000 else False
 | 
					 | 
				
			||||||
        rsv1 = True if head1 & 0b01000000 else False
 | 
					 | 
				
			||||||
        rsv2 = True if head1 & 0b00100000 else False
 | 
					 | 
				
			||||||
        rsv3 = True if head1 & 0b00010000 else False
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        try:
 | 
					 | 
				
			||||||
            opcode = Opcode(head1 & 0b00001111)
 | 
					 | 
				
			||||||
        except ValueError as exc:
 | 
					 | 
				
			||||||
            raise ProtocolError("invalid opcode") from exc
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if (True if head2 & 0b10000000 else False) != mask:
 | 
					 | 
				
			||||||
            raise ProtocolError("incorrect masking")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        length = head2 & 0b01111111
 | 
					 | 
				
			||||||
        if length == 126:
 | 
					 | 
				
			||||||
            data = yield from read_exact(2)
 | 
					 | 
				
			||||||
            (length,) = struct.unpack("!H", data)
 | 
					 | 
				
			||||||
        elif length == 127:
 | 
					 | 
				
			||||||
            data = yield from read_exact(8)
 | 
					 | 
				
			||||||
            (length,) = struct.unpack("!Q", data)
 | 
					 | 
				
			||||||
        if max_size is not None and length > max_size:
 | 
					 | 
				
			||||||
            raise PayloadTooBig(f"over size limit ({length} > {max_size} bytes)")
 | 
					 | 
				
			||||||
        if mask:
 | 
					 | 
				
			||||||
            mask_bytes = yield from read_exact(4)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # Read the data.
 | 
					 | 
				
			||||||
        data = yield from read_exact(length)
 | 
					 | 
				
			||||||
        if mask:
 | 
					 | 
				
			||||||
            data = apply_mask(data, mask_bytes)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        frame = cls(opcode, data, fin, rsv1, rsv2, rsv3)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if extensions is None:
 | 
					 | 
				
			||||||
            extensions = []
 | 
					 | 
				
			||||||
        for extension in reversed(extensions):
 | 
					 | 
				
			||||||
            frame = extension.decode(frame, max_size=max_size)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        frame.check()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        return frame
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def serialize(
 | 
					 | 
				
			||||||
        self,
 | 
					 | 
				
			||||||
        *,
 | 
					 | 
				
			||||||
        mask: bool,
 | 
					 | 
				
			||||||
        extensions: Sequence[extensions.Extension] | None = None,
 | 
					 | 
				
			||||||
    ) -> bytes:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Serialize a WebSocket frame.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Args:
 | 
					 | 
				
			||||||
            mask: Whether the frame should be masked i.e. whether the write
 | 
					 | 
				
			||||||
                happens on the client side.
 | 
					 | 
				
			||||||
            extensions: List of extensions, applied in order.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Raises:
 | 
					 | 
				
			||||||
            ProtocolError: If the frame contains incorrect values.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        self.check()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if extensions is None:
 | 
					 | 
				
			||||||
            extensions = []
 | 
					 | 
				
			||||||
        for extension in extensions:
 | 
					 | 
				
			||||||
            self = extension.encode(self)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        output = io.BytesIO()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # Prepare the header.
 | 
					 | 
				
			||||||
        head1 = (
 | 
					 | 
				
			||||||
            (0b10000000 if self.fin else 0)
 | 
					 | 
				
			||||||
            | (0b01000000 if self.rsv1 else 0)
 | 
					 | 
				
			||||||
            | (0b00100000 if self.rsv2 else 0)
 | 
					 | 
				
			||||||
            | (0b00010000 if self.rsv3 else 0)
 | 
					 | 
				
			||||||
            | self.opcode
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        head2 = 0b10000000 if mask else 0
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        length = len(self.data)
 | 
					 | 
				
			||||||
        if length < 126:
 | 
					 | 
				
			||||||
            output.write(struct.pack("!BB", head1, head2 | length))
 | 
					 | 
				
			||||||
        elif length < 65536:
 | 
					 | 
				
			||||||
            output.write(struct.pack("!BBH", head1, head2 | 126, length))
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            output.write(struct.pack("!BBQ", head1, head2 | 127, length))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if mask:
 | 
					 | 
				
			||||||
            mask_bytes = secrets.token_bytes(4)
 | 
					 | 
				
			||||||
            output.write(mask_bytes)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # Prepare the data.
 | 
					 | 
				
			||||||
        if mask:
 | 
					 | 
				
			||||||
            data = apply_mask(self.data, mask_bytes)
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            data = self.data
 | 
					 | 
				
			||||||
        output.write(data)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        return output.getvalue()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def check(self) -> None:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Check that reserved bits and opcode have acceptable values.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Raises:
 | 
					 | 
				
			||||||
            ProtocolError: If a reserved bit or the opcode is invalid.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        if self.rsv1 or self.rsv2 or self.rsv3:
 | 
					 | 
				
			||||||
            raise ProtocolError("reserved bits must be 0")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if self.opcode in CTRL_OPCODES:
 | 
					 | 
				
			||||||
            if len(self.data) > 125:
 | 
					 | 
				
			||||||
                raise ProtocolError("control frame too long")
 | 
					 | 
				
			||||||
            if not self.fin:
 | 
					 | 
				
			||||||
                raise ProtocolError("fragmented control frame")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
@dataclasses.dataclass
 | 
					 | 
				
			||||||
class Close:
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Code and reason for WebSocket close frames.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Attributes:
 | 
					 | 
				
			||||||
        code: Close code.
 | 
					 | 
				
			||||||
        reason: Close reason.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    code: int
 | 
					 | 
				
			||||||
    reason: str
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __str__(self) -> str:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Return a human-readable representation of a close code and reason.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        if 3000 <= self.code < 4000:
 | 
					 | 
				
			||||||
            explanation = "registered"
 | 
					 | 
				
			||||||
        elif 4000 <= self.code < 5000:
 | 
					 | 
				
			||||||
            explanation = "private use"
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            explanation = CLOSE_CODE_EXPLANATIONS.get(self.code, "unknown")
 | 
					 | 
				
			||||||
        result = f"{self.code} ({explanation})"
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if self.reason:
 | 
					 | 
				
			||||||
            result = f"{result} {self.reason}"
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        return result
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    @classmethod
 | 
					 | 
				
			||||||
    def parse(cls, data: bytes) -> Close:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Parse the payload of a close frame.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Args:
 | 
					 | 
				
			||||||
            data: Payload of the close frame.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Raises:
 | 
					 | 
				
			||||||
            ProtocolError: If data is ill-formed.
 | 
					 | 
				
			||||||
            UnicodeDecodeError: If the reason isn't valid UTF-8.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        if len(data) >= 2:
 | 
					 | 
				
			||||||
            (code,) = struct.unpack("!H", data[:2])
 | 
					 | 
				
			||||||
            reason = data[2:].decode()
 | 
					 | 
				
			||||||
            close = cls(code, reason)
 | 
					 | 
				
			||||||
            close.check()
 | 
					 | 
				
			||||||
            return close
 | 
					 | 
				
			||||||
        elif len(data) == 0:
 | 
					 | 
				
			||||||
            return cls(CloseCode.NO_STATUS_RCVD, "")
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            raise ProtocolError("close frame too short")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def serialize(self) -> bytes:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Serialize the payload of a close frame.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        self.check()
 | 
					 | 
				
			||||||
        return struct.pack("!H", self.code) + self.reason.encode()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def check(self) -> None:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Check that the close code has a valid value for a close frame.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Raises:
 | 
					 | 
				
			||||||
            ProtocolError: If the close code is invalid.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        if not (self.code in EXTERNAL_CLOSE_CODES or 3000 <= self.code < 5000):
 | 
					 | 
				
			||||||
            raise ProtocolError("invalid status code")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
# At the bottom to break import cycles created by type annotations.
 | 
					 | 
				
			||||||
from . import extensions  # noqa: E402
 | 
					 | 
				
			||||||
@@ -1,579 +0,0 @@
 | 
				
			|||||||
from __future__ import annotations
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
import base64
 | 
					 | 
				
			||||||
import binascii
 | 
					 | 
				
			||||||
import ipaddress
 | 
					 | 
				
			||||||
import re
 | 
					 | 
				
			||||||
from typing import Callable, Sequence, TypeVar, cast
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
from .exceptions import InvalidHeaderFormat, InvalidHeaderValue
 | 
					 | 
				
			||||||
from .typing import (
 | 
					 | 
				
			||||||
    ConnectionOption,
 | 
					 | 
				
			||||||
    ExtensionHeader,
 | 
					 | 
				
			||||||
    ExtensionName,
 | 
					 | 
				
			||||||
    ExtensionParameter,
 | 
					 | 
				
			||||||
    Subprotocol,
 | 
					 | 
				
			||||||
    UpgradeProtocol,
 | 
					 | 
				
			||||||
)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
__all__ = [
 | 
					 | 
				
			||||||
    "build_host",
 | 
					 | 
				
			||||||
    "parse_connection",
 | 
					 | 
				
			||||||
    "parse_upgrade",
 | 
					 | 
				
			||||||
    "parse_extension",
 | 
					 | 
				
			||||||
    "build_extension",
 | 
					 | 
				
			||||||
    "parse_subprotocol",
 | 
					 | 
				
			||||||
    "build_subprotocol",
 | 
					 | 
				
			||||||
    "validate_subprotocols",
 | 
					 | 
				
			||||||
    "build_www_authenticate_basic",
 | 
					 | 
				
			||||||
    "parse_authorization_basic",
 | 
					 | 
				
			||||||
    "build_authorization_basic",
 | 
					 | 
				
			||||||
]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
T = TypeVar("T")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def build_host(host: str, port: int, secure: bool) -> str:
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Build a ``Host`` header.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    # https://datatracker.ietf.org/doc/html/rfc3986#section-3.2.2
 | 
					 | 
				
			||||||
    # IPv6 addresses must be enclosed in brackets.
 | 
					 | 
				
			||||||
    try:
 | 
					 | 
				
			||||||
        address = ipaddress.ip_address(host)
 | 
					 | 
				
			||||||
    except ValueError:
 | 
					 | 
				
			||||||
        # host is a hostname
 | 
					 | 
				
			||||||
        pass
 | 
					 | 
				
			||||||
    else:
 | 
					 | 
				
			||||||
        # host is an IP address
 | 
					 | 
				
			||||||
        if address.version == 6:
 | 
					 | 
				
			||||||
            host = f"[{host}]"
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    if port != (443 if secure else 80):
 | 
					 | 
				
			||||||
        host = f"{host}:{port}"
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    return host
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
# To avoid a dependency on a parsing library, we implement manually the ABNF
 | 
					 | 
				
			||||||
# described in https://datatracker.ietf.org/doc/html/rfc6455#section-9.1 and
 | 
					 | 
				
			||||||
# https://datatracker.ietf.org/doc/html/rfc7230#appendix-B.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def peek_ahead(header: str, pos: int) -> str | None:
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Return the next character from ``header`` at the given position.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Return :obj:`None` at the end of ``header``.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    We never need to peek more than one character ahead.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    return None if pos == len(header) else header[pos]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
_OWS_re = re.compile(r"[\t ]*")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def parse_OWS(header: str, pos: int) -> int:
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Parse optional whitespace from ``header`` at the given position.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Return the new position.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    The whitespace itself isn't returned because it isn't significant.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    # There's always a match, possibly empty, whose content doesn't matter.
 | 
					 | 
				
			||||||
    match = _OWS_re.match(header, pos)
 | 
					 | 
				
			||||||
    assert match is not None
 | 
					 | 
				
			||||||
    return match.end()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
_token_re = re.compile(r"[-!#$%&\'*+.^_`|~0-9a-zA-Z]+")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def parse_token(header: str, pos: int, header_name: str) -> tuple[str, int]:
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Parse a token from ``header`` at the given position.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Return the token value and the new position.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Raises:
 | 
					 | 
				
			||||||
        InvalidHeaderFormat: On invalid inputs.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    match = _token_re.match(header, pos)
 | 
					 | 
				
			||||||
    if match is None:
 | 
					 | 
				
			||||||
        raise InvalidHeaderFormat(header_name, "expected token", header, pos)
 | 
					 | 
				
			||||||
    return match.group(), match.end()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
_quoted_string_re = re.compile(
 | 
					 | 
				
			||||||
    r'"(?:[\x09\x20-\x21\x23-\x5b\x5d-\x7e]|\\[\x09\x20-\x7e\x80-\xff])*"'
 | 
					 | 
				
			||||||
)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
_unquote_re = re.compile(r"\\([\x09\x20-\x7e\x80-\xff])")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def parse_quoted_string(header: str, pos: int, header_name: str) -> tuple[str, int]:
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Parse a quoted string from ``header`` at the given position.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Return the unquoted value and the new position.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Raises:
 | 
					 | 
				
			||||||
        InvalidHeaderFormat: On invalid inputs.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    match = _quoted_string_re.match(header, pos)
 | 
					 | 
				
			||||||
    if match is None:
 | 
					 | 
				
			||||||
        raise InvalidHeaderFormat(header_name, "expected quoted string", header, pos)
 | 
					 | 
				
			||||||
    return _unquote_re.sub(r"\1", match.group()[1:-1]), match.end()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
_quotable_re = re.compile(r"[\x09\x20-\x7e\x80-\xff]*")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
_quote_re = re.compile(r"([\x22\x5c])")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def build_quoted_string(value: str) -> str:
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Format ``value`` as a quoted string.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    This is the reverse of :func:`parse_quoted_string`.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    match = _quotable_re.fullmatch(value)
 | 
					 | 
				
			||||||
    if match is None:
 | 
					 | 
				
			||||||
        raise ValueError("invalid characters for quoted-string encoding")
 | 
					 | 
				
			||||||
    return '"' + _quote_re.sub(r"\\\1", value) + '"'
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def parse_list(
 | 
					 | 
				
			||||||
    parse_item: Callable[[str, int, str], tuple[T, int]],
 | 
					 | 
				
			||||||
    header: str,
 | 
					 | 
				
			||||||
    pos: int,
 | 
					 | 
				
			||||||
    header_name: str,
 | 
					 | 
				
			||||||
) -> list[T]:
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Parse a comma-separated list from ``header`` at the given position.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    This is appropriate for parsing values with the following grammar:
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        1#item
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    ``parse_item`` parses one item.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    ``header`` is assumed not to start or end with whitespace.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    (This function is designed for parsing an entire header value and
 | 
					 | 
				
			||||||
    :func:`~websockets.http.read_headers` strips whitespace from values.)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Return a list of items.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Raises:
 | 
					 | 
				
			||||||
        InvalidHeaderFormat: On invalid inputs.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    # Per https://datatracker.ietf.org/doc/html/rfc7230#section-7, "a recipient
 | 
					 | 
				
			||||||
    # MUST parse and ignore a reasonable number of empty list elements";
 | 
					 | 
				
			||||||
    # hence while loops that remove extra delimiters.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    # Remove extra delimiters before the first item.
 | 
					 | 
				
			||||||
    while peek_ahead(header, pos) == ",":
 | 
					 | 
				
			||||||
        pos = parse_OWS(header, pos + 1)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    items = []
 | 
					 | 
				
			||||||
    while True:
 | 
					 | 
				
			||||||
        # Loop invariant: a item starts at pos in header.
 | 
					 | 
				
			||||||
        item, pos = parse_item(header, pos, header_name)
 | 
					 | 
				
			||||||
        items.append(item)
 | 
					 | 
				
			||||||
        pos = parse_OWS(header, pos)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # We may have reached the end of the header.
 | 
					 | 
				
			||||||
        if pos == len(header):
 | 
					 | 
				
			||||||
            break
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # There must be a delimiter after each element except the last one.
 | 
					 | 
				
			||||||
        if peek_ahead(header, pos) == ",":
 | 
					 | 
				
			||||||
            pos = parse_OWS(header, pos + 1)
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            raise InvalidHeaderFormat(header_name, "expected comma", header, pos)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # Remove extra delimiters before the next item.
 | 
					 | 
				
			||||||
        while peek_ahead(header, pos) == ",":
 | 
					 | 
				
			||||||
            pos = parse_OWS(header, pos + 1)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # We may have reached the end of the header.
 | 
					 | 
				
			||||||
        if pos == len(header):
 | 
					 | 
				
			||||||
            break
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    # Since we only advance in the header by one character with peek_ahead()
 | 
					 | 
				
			||||||
    # or with the end position of a regex match, we can't overshoot the end.
 | 
					 | 
				
			||||||
    assert pos == len(header)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    return items
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def parse_connection_option(
 | 
					 | 
				
			||||||
    header: str, pos: int, header_name: str
 | 
					 | 
				
			||||||
) -> tuple[ConnectionOption, int]:
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Parse a Connection option from ``header`` at the given position.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Return the protocol value and the new position.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Raises:
 | 
					 | 
				
			||||||
        InvalidHeaderFormat: On invalid inputs.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    item, pos = parse_token(header, pos, header_name)
 | 
					 | 
				
			||||||
    return cast(ConnectionOption, item), pos
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def parse_connection(header: str) -> list[ConnectionOption]:
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Parse a ``Connection`` header.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Return a list of HTTP connection options.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Args
 | 
					 | 
				
			||||||
        header: value of the ``Connection`` header.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Raises:
 | 
					 | 
				
			||||||
        InvalidHeaderFormat: On invalid inputs.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    return parse_list(parse_connection_option, header, 0, "Connection")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
_protocol_re = re.compile(
 | 
					 | 
				
			||||||
    r"[-!#$%&\'*+.^_`|~0-9a-zA-Z]+(?:/[-!#$%&\'*+.^_`|~0-9a-zA-Z]+)?"
 | 
					 | 
				
			||||||
)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def parse_upgrade_protocol(
 | 
					 | 
				
			||||||
    header: str, pos: int, header_name: str
 | 
					 | 
				
			||||||
) -> tuple[UpgradeProtocol, int]:
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Parse an Upgrade protocol from ``header`` at the given position.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Return the protocol value and the new position.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Raises:
 | 
					 | 
				
			||||||
        InvalidHeaderFormat: On invalid inputs.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    match = _protocol_re.match(header, pos)
 | 
					 | 
				
			||||||
    if match is None:
 | 
					 | 
				
			||||||
        raise InvalidHeaderFormat(header_name, "expected protocol", header, pos)
 | 
					 | 
				
			||||||
    return cast(UpgradeProtocol, match.group()), match.end()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def parse_upgrade(header: str) -> list[UpgradeProtocol]:
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Parse an ``Upgrade`` header.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Return a list of HTTP protocols.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Args:
 | 
					 | 
				
			||||||
        header: Value of the ``Upgrade`` header.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Raises:
 | 
					 | 
				
			||||||
        InvalidHeaderFormat: On invalid inputs.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    return parse_list(parse_upgrade_protocol, header, 0, "Upgrade")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def parse_extension_item_param(
 | 
					 | 
				
			||||||
    header: str, pos: int, header_name: str
 | 
					 | 
				
			||||||
) -> tuple[ExtensionParameter, int]:
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Parse a single extension parameter from ``header`` at the given position.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Return a ``(name, value)`` pair and the new position.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Raises:
 | 
					 | 
				
			||||||
        InvalidHeaderFormat: On invalid inputs.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    # Extract parameter name.
 | 
					 | 
				
			||||||
    name, pos = parse_token(header, pos, header_name)
 | 
					 | 
				
			||||||
    pos = parse_OWS(header, pos)
 | 
					 | 
				
			||||||
    # Extract parameter value, if there is one.
 | 
					 | 
				
			||||||
    value: str | None = None
 | 
					 | 
				
			||||||
    if peek_ahead(header, pos) == "=":
 | 
					 | 
				
			||||||
        pos = parse_OWS(header, pos + 1)
 | 
					 | 
				
			||||||
        if peek_ahead(header, pos) == '"':
 | 
					 | 
				
			||||||
            pos_before = pos  # for proper error reporting below
 | 
					 | 
				
			||||||
            value, pos = parse_quoted_string(header, pos, header_name)
 | 
					 | 
				
			||||||
            # https://datatracker.ietf.org/doc/html/rfc6455#section-9.1 says:
 | 
					 | 
				
			||||||
            # the value after quoted-string unescaping MUST conform to
 | 
					 | 
				
			||||||
            # the 'token' ABNF.
 | 
					 | 
				
			||||||
            if _token_re.fullmatch(value) is None:
 | 
					 | 
				
			||||||
                raise InvalidHeaderFormat(
 | 
					 | 
				
			||||||
                    header_name, "invalid quoted header content", header, pos_before
 | 
					 | 
				
			||||||
                )
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            value, pos = parse_token(header, pos, header_name)
 | 
					 | 
				
			||||||
        pos = parse_OWS(header, pos)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    return (name, value), pos
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def parse_extension_item(
 | 
					 | 
				
			||||||
    header: str, pos: int, header_name: str
 | 
					 | 
				
			||||||
) -> tuple[ExtensionHeader, int]:
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Parse an extension definition from ``header`` at the given position.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Return an ``(extension name, parameters)`` pair, where ``parameters`` is a
 | 
					 | 
				
			||||||
    list of ``(name, value)`` pairs, and the new position.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Raises:
 | 
					 | 
				
			||||||
        InvalidHeaderFormat: On invalid inputs.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    # Extract extension name.
 | 
					 | 
				
			||||||
    name, pos = parse_token(header, pos, header_name)
 | 
					 | 
				
			||||||
    pos = parse_OWS(header, pos)
 | 
					 | 
				
			||||||
    # Extract all parameters.
 | 
					 | 
				
			||||||
    parameters = []
 | 
					 | 
				
			||||||
    while peek_ahead(header, pos) == ";":
 | 
					 | 
				
			||||||
        pos = parse_OWS(header, pos + 1)
 | 
					 | 
				
			||||||
        parameter, pos = parse_extension_item_param(header, pos, header_name)
 | 
					 | 
				
			||||||
        parameters.append(parameter)
 | 
					 | 
				
			||||||
    return (cast(ExtensionName, name), parameters), pos
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def parse_extension(header: str) -> list[ExtensionHeader]:
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Parse a ``Sec-WebSocket-Extensions`` header.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Return a list of WebSocket extensions and their parameters in this format::
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        [
 | 
					 | 
				
			||||||
            (
 | 
					 | 
				
			||||||
                'extension name',
 | 
					 | 
				
			||||||
                [
 | 
					 | 
				
			||||||
                    ('parameter name', 'parameter value'),
 | 
					 | 
				
			||||||
                    ....
 | 
					 | 
				
			||||||
                ]
 | 
					 | 
				
			||||||
            ),
 | 
					 | 
				
			||||||
            ...
 | 
					 | 
				
			||||||
        ]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Parameter values are :obj:`None` when no value is provided.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Raises:
 | 
					 | 
				
			||||||
        InvalidHeaderFormat: On invalid inputs.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    return parse_list(parse_extension_item, header, 0, "Sec-WebSocket-Extensions")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
parse_extension_list = parse_extension  # alias for backwards compatibility
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def build_extension_item(
 | 
					 | 
				
			||||||
    name: ExtensionName, parameters: list[ExtensionParameter]
 | 
					 | 
				
			||||||
) -> str:
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Build an extension definition.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    This is the reverse of :func:`parse_extension_item`.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    return "; ".join(
 | 
					 | 
				
			||||||
        [cast(str, name)]
 | 
					 | 
				
			||||||
        + [
 | 
					 | 
				
			||||||
            # Quoted strings aren't necessary because values are always tokens.
 | 
					 | 
				
			||||||
            name if value is None else f"{name}={value}"
 | 
					 | 
				
			||||||
            for name, value in parameters
 | 
					 | 
				
			||||||
        ]
 | 
					 | 
				
			||||||
    )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def build_extension(extensions: Sequence[ExtensionHeader]) -> str:
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Build a ``Sec-WebSocket-Extensions`` header.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    This is the reverse of :func:`parse_extension`.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    return ", ".join(
 | 
					 | 
				
			||||||
        build_extension_item(name, parameters) for name, parameters in extensions
 | 
					 | 
				
			||||||
    )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
build_extension_list = build_extension  # alias for backwards compatibility
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def parse_subprotocol_item(
 | 
					 | 
				
			||||||
    header: str, pos: int, header_name: str
 | 
					 | 
				
			||||||
) -> tuple[Subprotocol, int]:
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Parse a subprotocol from ``header`` at the given position.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Return the subprotocol value and the new position.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Raises:
 | 
					 | 
				
			||||||
        InvalidHeaderFormat: On invalid inputs.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    item, pos = parse_token(header, pos, header_name)
 | 
					 | 
				
			||||||
    return cast(Subprotocol, item), pos
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def parse_subprotocol(header: str) -> list[Subprotocol]:
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Parse a ``Sec-WebSocket-Protocol`` header.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Return a list of WebSocket subprotocols.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Raises:
 | 
					 | 
				
			||||||
        InvalidHeaderFormat: On invalid inputs.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    return parse_list(parse_subprotocol_item, header, 0, "Sec-WebSocket-Protocol")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
parse_subprotocol_list = parse_subprotocol  # alias for backwards compatibility
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def build_subprotocol(subprotocols: Sequence[Subprotocol]) -> str:
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Build a ``Sec-WebSocket-Protocol`` header.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    This is the reverse of :func:`parse_subprotocol`.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    return ", ".join(subprotocols)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
build_subprotocol_list = build_subprotocol  # alias for backwards compatibility
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def validate_subprotocols(subprotocols: Sequence[Subprotocol]) -> None:
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Validate that ``subprotocols`` is suitable for :func:`build_subprotocol`.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    if not isinstance(subprotocols, Sequence):
 | 
					 | 
				
			||||||
        raise TypeError("subprotocols must be a list")
 | 
					 | 
				
			||||||
    if isinstance(subprotocols, str):
 | 
					 | 
				
			||||||
        raise TypeError("subprotocols must be a list, not a str")
 | 
					 | 
				
			||||||
    for subprotocol in subprotocols:
 | 
					 | 
				
			||||||
        if not _token_re.fullmatch(subprotocol):
 | 
					 | 
				
			||||||
            raise ValueError(f"invalid subprotocol: {subprotocol}")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def build_www_authenticate_basic(realm: str) -> str:
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Build a ``WWW-Authenticate`` header for HTTP Basic Auth.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Args:
 | 
					 | 
				
			||||||
        realm: Identifier of the protection space.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    # https://datatracker.ietf.org/doc/html/rfc7617#section-2
 | 
					 | 
				
			||||||
    realm = build_quoted_string(realm)
 | 
					 | 
				
			||||||
    charset = build_quoted_string("UTF-8")
 | 
					 | 
				
			||||||
    return f"Basic realm={realm}, charset={charset}"
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
_token68_re = re.compile(r"[A-Za-z0-9-._~+/]+=*")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def parse_token68(header: str, pos: int, header_name: str) -> tuple[str, int]:
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Parse a token68 from ``header`` at the given position.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Return the token value and the new position.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Raises:
 | 
					 | 
				
			||||||
        InvalidHeaderFormat: On invalid inputs.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    match = _token68_re.match(header, pos)
 | 
					 | 
				
			||||||
    if match is None:
 | 
					 | 
				
			||||||
        raise InvalidHeaderFormat(header_name, "expected token68", header, pos)
 | 
					 | 
				
			||||||
    return match.group(), match.end()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def parse_end(header: str, pos: int, header_name: str) -> None:
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Check that parsing reached the end of header.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    if pos < len(header):
 | 
					 | 
				
			||||||
        raise InvalidHeaderFormat(header_name, "trailing data", header, pos)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def parse_authorization_basic(header: str) -> tuple[str, str]:
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Parse an ``Authorization`` header for HTTP Basic Auth.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Return a ``(username, password)`` tuple.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Args:
 | 
					 | 
				
			||||||
        header: Value of the ``Authorization`` header.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Raises:
 | 
					 | 
				
			||||||
        InvalidHeaderFormat: On invalid inputs.
 | 
					 | 
				
			||||||
        InvalidHeaderValue: On unsupported inputs.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    # https://datatracker.ietf.org/doc/html/rfc7235#section-2.1
 | 
					 | 
				
			||||||
    # https://datatracker.ietf.org/doc/html/rfc7617#section-2
 | 
					 | 
				
			||||||
    scheme, pos = parse_token(header, 0, "Authorization")
 | 
					 | 
				
			||||||
    if scheme.lower() != "basic":
 | 
					 | 
				
			||||||
        raise InvalidHeaderValue(
 | 
					 | 
				
			||||||
            "Authorization",
 | 
					 | 
				
			||||||
            f"unsupported scheme: {scheme}",
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
    if peek_ahead(header, pos) != " ":
 | 
					 | 
				
			||||||
        raise InvalidHeaderFormat(
 | 
					 | 
				
			||||||
            "Authorization", "expected space after scheme", header, pos
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
    pos += 1
 | 
					 | 
				
			||||||
    basic_credentials, pos = parse_token68(header, pos, "Authorization")
 | 
					 | 
				
			||||||
    parse_end(header, pos, "Authorization")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    try:
 | 
					 | 
				
			||||||
        user_pass = base64.b64decode(basic_credentials.encode()).decode()
 | 
					 | 
				
			||||||
    except binascii.Error:
 | 
					 | 
				
			||||||
        raise InvalidHeaderValue(
 | 
					 | 
				
			||||||
            "Authorization",
 | 
					 | 
				
			||||||
            "expected base64-encoded credentials",
 | 
					 | 
				
			||||||
        ) from None
 | 
					 | 
				
			||||||
    try:
 | 
					 | 
				
			||||||
        username, password = user_pass.split(":", 1)
 | 
					 | 
				
			||||||
    except ValueError:
 | 
					 | 
				
			||||||
        raise InvalidHeaderValue(
 | 
					 | 
				
			||||||
            "Authorization",
 | 
					 | 
				
			||||||
            "expected username:password credentials",
 | 
					 | 
				
			||||||
        ) from None
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    return username, password
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def build_authorization_basic(username: str, password: str) -> str:
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Build an ``Authorization`` header for HTTP Basic Auth.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    This is the reverse of :func:`parse_authorization_basic`.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    # https://datatracker.ietf.org/doc/html/rfc7617#section-2
 | 
					 | 
				
			||||||
    assert ":" not in username
 | 
					 | 
				
			||||||
    user_pass = f"{username}:{password}"
 | 
					 | 
				
			||||||
    basic_credentials = base64.b64encode(user_pass.encode()).decode()
 | 
					 | 
				
			||||||
    return "Basic " + basic_credentials
 | 
					 | 
				
			||||||
@@ -1,15 +0,0 @@
 | 
				
			|||||||
from __future__ import annotations
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
import warnings
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
from .datastructures import Headers, MultipleValuesError  # noqa: F401
 | 
					 | 
				
			||||||
from .legacy.http import read_request, read_response  # noqa: F401
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
warnings.warn(  # deprecated in 9.0 - 2021-09-01
 | 
					 | 
				
			||||||
    "Headers and MultipleValuesError were moved "
 | 
					 | 
				
			||||||
    "from websockets.http to websockets.datastructures"
 | 
					 | 
				
			||||||
    "and read_request and read_response were moved "
 | 
					 | 
				
			||||||
    "from websockets.http to websockets.legacy.http",
 | 
					 | 
				
			||||||
    DeprecationWarning,
 | 
					 | 
				
			||||||
)
 | 
					 | 
				
			||||||
@@ -1,385 +0,0 @@
 | 
				
			|||||||
from __future__ import annotations
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
import dataclasses
 | 
					 | 
				
			||||||
import os
 | 
					 | 
				
			||||||
import re
 | 
					 | 
				
			||||||
import sys
 | 
					 | 
				
			||||||
import warnings
 | 
					 | 
				
			||||||
from typing import Callable, Generator
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
from .datastructures import Headers
 | 
					 | 
				
			||||||
from .exceptions import SecurityError
 | 
					 | 
				
			||||||
from .version import version as websockets_version
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
__all__ = ["SERVER", "USER_AGENT", "Request", "Response"]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
PYTHON_VERSION = "{}.{}".format(*sys.version_info)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
# User-Agent header for HTTP requests.
 | 
					 | 
				
			||||||
USER_AGENT = os.environ.get(
 | 
					 | 
				
			||||||
    "WEBSOCKETS_USER_AGENT",
 | 
					 | 
				
			||||||
    f"Python/{PYTHON_VERSION} websockets/{websockets_version}",
 | 
					 | 
				
			||||||
)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
# Server header for HTTP responses.
 | 
					 | 
				
			||||||
SERVER = os.environ.get(
 | 
					 | 
				
			||||||
    "WEBSOCKETS_SERVER",
 | 
					 | 
				
			||||||
    f"Python/{PYTHON_VERSION} websockets/{websockets_version}",
 | 
					 | 
				
			||||||
)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
# Maximum total size of headers is around 128 * 8 KiB = 1 MiB.
 | 
					 | 
				
			||||||
MAX_NUM_HEADERS = int(os.environ.get("WEBSOCKETS_MAX_NUM_HEADERS", "128"))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
# Limit request line and header lines. 8KiB is the most common default
 | 
					 | 
				
			||||||
# configuration of popular HTTP servers.
 | 
					 | 
				
			||||||
MAX_LINE_LENGTH = int(os.environ.get("WEBSOCKETS_MAX_LINE_LENGTH", "8192"))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
# Support for HTTP response bodies is intended to read an error message
 | 
					 | 
				
			||||||
# returned by a server. It isn't designed to perform large file transfers.
 | 
					 | 
				
			||||||
MAX_BODY_SIZE = int(os.environ.get("WEBSOCKETS_MAX_BODY_SIZE", "1_048_576"))  # 1 MiB
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def d(value: bytes) -> str:
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Decode a bytestring for interpolating into an error message.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    return value.decode(errors="backslashreplace")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
# See https://datatracker.ietf.org/doc/html/rfc7230#appendix-B.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
# Regex for validating header names.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
_token_re = re.compile(rb"[-!#$%&\'*+.^_`|~0-9a-zA-Z]+")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
# Regex for validating header values.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
# We don't attempt to support obsolete line folding.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
# Include HTAB (\x09), SP (\x20), VCHAR (\x21-\x7e), obs-text (\x80-\xff).
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
# The ABNF is complicated because it attempts to express that optional
 | 
					 | 
				
			||||||
# whitespace is ignored. We strip whitespace and don't revalidate that.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
# See also https://www.rfc-editor.org/errata_search.php?rfc=7230&eid=4189
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
_value_re = re.compile(rb"[\x09\x20-\x7e\x80-\xff]*")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
@dataclasses.dataclass
 | 
					 | 
				
			||||||
class Request:
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    WebSocket handshake request.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Attributes:
 | 
					 | 
				
			||||||
        path: Request path, including optional query.
 | 
					 | 
				
			||||||
        headers: Request headers.
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    path: str
 | 
					 | 
				
			||||||
    headers: Headers
 | 
					 | 
				
			||||||
    # body isn't useful is the context of this library.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    _exception: Exception | None = None
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    @property
 | 
					 | 
				
			||||||
    def exception(self) -> Exception | None:  # pragma: no cover
 | 
					 | 
				
			||||||
        warnings.warn(  # deprecated in 10.3 - 2022-04-17
 | 
					 | 
				
			||||||
            "Request.exception is deprecated; "
 | 
					 | 
				
			||||||
            "use ServerProtocol.handshake_exc instead",
 | 
					 | 
				
			||||||
            DeprecationWarning,
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        return self._exception
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    @classmethod
 | 
					 | 
				
			||||||
    def parse(
 | 
					 | 
				
			||||||
        cls,
 | 
					 | 
				
			||||||
        read_line: Callable[[int], Generator[None, None, bytes]],
 | 
					 | 
				
			||||||
    ) -> Generator[None, None, Request]:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Parse a WebSocket handshake request.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        This is a generator-based coroutine.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        The request path isn't URL-decoded or validated in any way.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        The request path and headers are expected to contain only ASCII
 | 
					 | 
				
			||||||
        characters. Other characters are represented with surrogate escapes.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        :meth:`parse` doesn't attempt to read the request body because
 | 
					 | 
				
			||||||
        WebSocket handshake requests don't have one. If the request contains a
 | 
					 | 
				
			||||||
        body, it may be read from the data stream after :meth:`parse` returns.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Args:
 | 
					 | 
				
			||||||
            read_line: Generator-based coroutine that reads a LF-terminated
 | 
					 | 
				
			||||||
                line or raises an exception if there isn't enough data
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Raises:
 | 
					 | 
				
			||||||
            EOFError: If the connection is closed without a full HTTP request.
 | 
					 | 
				
			||||||
            SecurityError: If the request exceeds a security limit.
 | 
					 | 
				
			||||||
            ValueError: If the request isn't well formatted.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        # https://datatracker.ietf.org/doc/html/rfc7230#section-3.1.1
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # Parsing is simple because fixed values are expected for method and
 | 
					 | 
				
			||||||
        # version and because path isn't checked. Since WebSocket software tends
 | 
					 | 
				
			||||||
        # to implement HTTP/1.1 strictly, there's little need for lenient parsing.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        try:
 | 
					 | 
				
			||||||
            request_line = yield from parse_line(read_line)
 | 
					 | 
				
			||||||
        except EOFError as exc:
 | 
					 | 
				
			||||||
            raise EOFError("connection closed while reading HTTP request line") from exc
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        try:
 | 
					 | 
				
			||||||
            method, raw_path, protocol = request_line.split(b" ", 2)
 | 
					 | 
				
			||||||
        except ValueError:  # not enough values to unpack (expected 3, got 1-2)
 | 
					 | 
				
			||||||
            raise ValueError(f"invalid HTTP request line: {d(request_line)}") from None
 | 
					 | 
				
			||||||
        if protocol != b"HTTP/1.1":
 | 
					 | 
				
			||||||
            raise ValueError(
 | 
					 | 
				
			||||||
                f"unsupported protocol; expected HTTP/1.1: {d(request_line)}"
 | 
					 | 
				
			||||||
            )
 | 
					 | 
				
			||||||
        if method != b"GET":
 | 
					 | 
				
			||||||
            raise ValueError(f"unsupported HTTP method; expected GET; got {d(method)}")
 | 
					 | 
				
			||||||
        path = raw_path.decode("ascii", "surrogateescape")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        headers = yield from parse_headers(read_line)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # https://datatracker.ietf.org/doc/html/rfc7230#section-3.3.3
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if "Transfer-Encoding" in headers:
 | 
					 | 
				
			||||||
            raise NotImplementedError("transfer codings aren't supported")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if "Content-Length" in headers:
 | 
					 | 
				
			||||||
            raise ValueError("unsupported request body")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        return cls(path, headers)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def serialize(self) -> bytes:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Serialize a WebSocket handshake request.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        # Since the request line and headers only contain ASCII characters,
 | 
					 | 
				
			||||||
        # we can keep this simple.
 | 
					 | 
				
			||||||
        request = f"GET {self.path} HTTP/1.1\r\n".encode()
 | 
					 | 
				
			||||||
        request += self.headers.serialize()
 | 
					 | 
				
			||||||
        return request
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
@dataclasses.dataclass
 | 
					 | 
				
			||||||
class Response:
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    WebSocket handshake response.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Attributes:
 | 
					 | 
				
			||||||
        status_code: Response code.
 | 
					 | 
				
			||||||
        reason_phrase: Response reason.
 | 
					 | 
				
			||||||
        headers: Response headers.
 | 
					 | 
				
			||||||
        body: Response body, if any.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    status_code: int
 | 
					 | 
				
			||||||
    reason_phrase: str
 | 
					 | 
				
			||||||
    headers: Headers
 | 
					 | 
				
			||||||
    body: bytes | None = None
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    _exception: Exception | None = None
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    @property
 | 
					 | 
				
			||||||
    def exception(self) -> Exception | None:  # pragma: no cover
 | 
					 | 
				
			||||||
        warnings.warn(  # deprecated in 10.3 - 2022-04-17
 | 
					 | 
				
			||||||
            "Response.exception is deprecated; "
 | 
					 | 
				
			||||||
            "use ClientProtocol.handshake_exc instead",
 | 
					 | 
				
			||||||
            DeprecationWarning,
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        return self._exception
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    @classmethod
 | 
					 | 
				
			||||||
    def parse(
 | 
					 | 
				
			||||||
        cls,
 | 
					 | 
				
			||||||
        read_line: Callable[[int], Generator[None, None, bytes]],
 | 
					 | 
				
			||||||
        read_exact: Callable[[int], Generator[None, None, bytes]],
 | 
					 | 
				
			||||||
        read_to_eof: Callable[[int], Generator[None, None, bytes]],
 | 
					 | 
				
			||||||
    ) -> Generator[None, None, Response]:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Parse a WebSocket handshake response.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        This is a generator-based coroutine.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        The reason phrase and headers are expected to contain only ASCII
 | 
					 | 
				
			||||||
        characters. Other characters are represented with surrogate escapes.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Args:
 | 
					 | 
				
			||||||
            read_line: Generator-based coroutine that reads a LF-terminated
 | 
					 | 
				
			||||||
                line or raises an exception if there isn't enough data.
 | 
					 | 
				
			||||||
            read_exact: Generator-based coroutine that reads the requested
 | 
					 | 
				
			||||||
                bytes or raises an exception if there isn't enough data.
 | 
					 | 
				
			||||||
            read_to_eof: Generator-based coroutine that reads until the end
 | 
					 | 
				
			||||||
                of the stream.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Raises:
 | 
					 | 
				
			||||||
            EOFError: If the connection is closed without a full HTTP response.
 | 
					 | 
				
			||||||
            SecurityError: If the response exceeds a security limit.
 | 
					 | 
				
			||||||
            LookupError: If the response isn't well formatted.
 | 
					 | 
				
			||||||
            ValueError: If the response isn't well formatted.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        # https://datatracker.ietf.org/doc/html/rfc7230#section-3.1.2
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        try:
 | 
					 | 
				
			||||||
            status_line = yield from parse_line(read_line)
 | 
					 | 
				
			||||||
        except EOFError as exc:
 | 
					 | 
				
			||||||
            raise EOFError("connection closed while reading HTTP status line") from exc
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        try:
 | 
					 | 
				
			||||||
            protocol, raw_status_code, raw_reason = status_line.split(b" ", 2)
 | 
					 | 
				
			||||||
        except ValueError:  # not enough values to unpack (expected 3, got 1-2)
 | 
					 | 
				
			||||||
            raise ValueError(f"invalid HTTP status line: {d(status_line)}") from None
 | 
					 | 
				
			||||||
        if protocol != b"HTTP/1.1":
 | 
					 | 
				
			||||||
            raise ValueError(
 | 
					 | 
				
			||||||
                f"unsupported protocol; expected HTTP/1.1: {d(status_line)}"
 | 
					 | 
				
			||||||
            )
 | 
					 | 
				
			||||||
        try:
 | 
					 | 
				
			||||||
            status_code = int(raw_status_code)
 | 
					 | 
				
			||||||
        except ValueError:  # invalid literal for int() with base 10
 | 
					 | 
				
			||||||
            raise ValueError(
 | 
					 | 
				
			||||||
                f"invalid status code; expected integer; got {d(raw_status_code)}"
 | 
					 | 
				
			||||||
            ) from None
 | 
					 | 
				
			||||||
        if not 100 <= status_code < 600:
 | 
					 | 
				
			||||||
            raise ValueError(
 | 
					 | 
				
			||||||
                f"invalid status code; expected 100–599; got {d(raw_status_code)}"
 | 
					 | 
				
			||||||
            )
 | 
					 | 
				
			||||||
        if not _value_re.fullmatch(raw_reason):
 | 
					 | 
				
			||||||
            raise ValueError(f"invalid HTTP reason phrase: {d(raw_reason)}")
 | 
					 | 
				
			||||||
        reason = raw_reason.decode("ascii", "surrogateescape")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        headers = yield from parse_headers(read_line)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # https://datatracker.ietf.org/doc/html/rfc7230#section-3.3.3
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if "Transfer-Encoding" in headers:
 | 
					 | 
				
			||||||
            raise NotImplementedError("transfer codings aren't supported")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # Since websockets only does GET requests (no HEAD, no CONNECT), all
 | 
					 | 
				
			||||||
        # responses except 1xx, 204, and 304 include a message body.
 | 
					 | 
				
			||||||
        if 100 <= status_code < 200 or status_code == 204 or status_code == 304:
 | 
					 | 
				
			||||||
            body = None
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            content_length: int | None
 | 
					 | 
				
			||||||
            try:
 | 
					 | 
				
			||||||
                # MultipleValuesError is sufficiently unlikely that we don't
 | 
					 | 
				
			||||||
                # attempt to handle it. Instead we document that its parent
 | 
					 | 
				
			||||||
                # class, LookupError, may be raised.
 | 
					 | 
				
			||||||
                raw_content_length = headers["Content-Length"]
 | 
					 | 
				
			||||||
            except KeyError:
 | 
					 | 
				
			||||||
                content_length = None
 | 
					 | 
				
			||||||
            else:
 | 
					 | 
				
			||||||
                content_length = int(raw_content_length)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            if content_length is None:
 | 
					 | 
				
			||||||
                try:
 | 
					 | 
				
			||||||
                    body = yield from read_to_eof(MAX_BODY_SIZE)
 | 
					 | 
				
			||||||
                except RuntimeError:
 | 
					 | 
				
			||||||
                    raise SecurityError(f"body too large: over {MAX_BODY_SIZE} bytes")
 | 
					 | 
				
			||||||
            elif content_length > MAX_BODY_SIZE:
 | 
					 | 
				
			||||||
                raise SecurityError(f"body too large: {content_length} bytes")
 | 
					 | 
				
			||||||
            else:
 | 
					 | 
				
			||||||
                body = yield from read_exact(content_length)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        return cls(status_code, reason, headers, body)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def serialize(self) -> bytes:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Serialize a WebSocket handshake response.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        # Since the status line and headers only contain ASCII characters,
 | 
					 | 
				
			||||||
        # we can keep this simple.
 | 
					 | 
				
			||||||
        response = f"HTTP/1.1 {self.status_code} {self.reason_phrase}\r\n".encode()
 | 
					 | 
				
			||||||
        response += self.headers.serialize()
 | 
					 | 
				
			||||||
        if self.body is not None:
 | 
					 | 
				
			||||||
            response += self.body
 | 
					 | 
				
			||||||
        return response
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def parse_headers(
 | 
					 | 
				
			||||||
    read_line: Callable[[int], Generator[None, None, bytes]],
 | 
					 | 
				
			||||||
) -> Generator[None, None, Headers]:
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Parse HTTP headers.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Non-ASCII characters are represented with surrogate escapes.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Args:
 | 
					 | 
				
			||||||
        read_line: Generator-based coroutine that reads a LF-terminated line
 | 
					 | 
				
			||||||
            or raises an exception if there isn't enough data.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Raises:
 | 
					 | 
				
			||||||
        EOFError: If the connection is closed without complete headers.
 | 
					 | 
				
			||||||
        SecurityError: If the request exceeds a security limit.
 | 
					 | 
				
			||||||
        ValueError: If the request isn't well formatted.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    # https://datatracker.ietf.org/doc/html/rfc7230#section-3.2
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    # We don't attempt to support obsolete line folding.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    headers = Headers()
 | 
					 | 
				
			||||||
    for _ in range(MAX_NUM_HEADERS + 1):
 | 
					 | 
				
			||||||
        try:
 | 
					 | 
				
			||||||
            line = yield from parse_line(read_line)
 | 
					 | 
				
			||||||
        except EOFError as exc:
 | 
					 | 
				
			||||||
            raise EOFError("connection closed while reading HTTP headers") from exc
 | 
					 | 
				
			||||||
        if line == b"":
 | 
					 | 
				
			||||||
            break
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        try:
 | 
					 | 
				
			||||||
            raw_name, raw_value = line.split(b":", 1)
 | 
					 | 
				
			||||||
        except ValueError:  # not enough values to unpack (expected 2, got 1)
 | 
					 | 
				
			||||||
            raise ValueError(f"invalid HTTP header line: {d(line)}") from None
 | 
					 | 
				
			||||||
        if not _token_re.fullmatch(raw_name):
 | 
					 | 
				
			||||||
            raise ValueError(f"invalid HTTP header name: {d(raw_name)}")
 | 
					 | 
				
			||||||
        raw_value = raw_value.strip(b" \t")
 | 
					 | 
				
			||||||
        if not _value_re.fullmatch(raw_value):
 | 
					 | 
				
			||||||
            raise ValueError(f"invalid HTTP header value: {d(raw_value)}")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        name = raw_name.decode("ascii")  # guaranteed to be ASCII at this point
 | 
					 | 
				
			||||||
        value = raw_value.decode("ascii", "surrogateescape")
 | 
					 | 
				
			||||||
        headers[name] = value
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    else:
 | 
					 | 
				
			||||||
        raise SecurityError("too many HTTP headers")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    return headers
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def parse_line(
 | 
					 | 
				
			||||||
    read_line: Callable[[int], Generator[None, None, bytes]],
 | 
					 | 
				
			||||||
) -> Generator[None, None, bytes]:
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Parse a single line.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    CRLF is stripped from the return value.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Args:
 | 
					 | 
				
			||||||
        read_line: Generator-based coroutine that reads a LF-terminated line
 | 
					 | 
				
			||||||
            or raises an exception if there isn't enough data.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Raises:
 | 
					 | 
				
			||||||
        EOFError: If the connection is closed without a CRLF.
 | 
					 | 
				
			||||||
        SecurityError: If the response exceeds a security limit.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    try:
 | 
					 | 
				
			||||||
        line = yield from read_line(MAX_LINE_LENGTH)
 | 
					 | 
				
			||||||
    except RuntimeError:
 | 
					 | 
				
			||||||
        raise SecurityError("line too long")
 | 
					 | 
				
			||||||
    # Not mandatory but safe - https://datatracker.ietf.org/doc/html/rfc7230#section-3.5
 | 
					 | 
				
			||||||
    if not line.endswith(b"\r\n"):
 | 
					 | 
				
			||||||
        raise EOFError("line without CRLF")
 | 
					 | 
				
			||||||
    return line[:-2]
 | 
					 | 
				
			||||||
@@ -1,99 +0,0 @@
 | 
				
			|||||||
from __future__ import annotations
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
import warnings
 | 
					 | 
				
			||||||
from typing import Any, Iterable
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
__all__ = ["lazy_import"]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def import_name(name: str, source: str, namespace: dict[str, Any]) -> Any:
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Import ``name`` from ``source`` in ``namespace``.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    There are two use cases:
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    - ``name`` is an object defined in ``source``;
 | 
					 | 
				
			||||||
    - ``name`` is a submodule of ``source``.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Neither :func:`__import__` nor :func:`~importlib.import_module` does
 | 
					 | 
				
			||||||
    exactly this. :func:`__import__` is closer to the intended behavior.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    level = 0
 | 
					 | 
				
			||||||
    while source[level] == ".":
 | 
					 | 
				
			||||||
        level += 1
 | 
					 | 
				
			||||||
        assert level < len(source), "importing from parent isn't supported"
 | 
					 | 
				
			||||||
    module = __import__(source[level:], namespace, None, [name], level)
 | 
					 | 
				
			||||||
    return getattr(module, name)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def lazy_import(
 | 
					 | 
				
			||||||
    namespace: dict[str, Any],
 | 
					 | 
				
			||||||
    aliases: dict[str, str] | None = None,
 | 
					 | 
				
			||||||
    deprecated_aliases: dict[str, str] | None = None,
 | 
					 | 
				
			||||||
) -> None:
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Provide lazy, module-level imports.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Typical use::
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        __getattr__, __dir__ = lazy_import(
 | 
					 | 
				
			||||||
            globals(),
 | 
					 | 
				
			||||||
            aliases={
 | 
					 | 
				
			||||||
                "<name>": "<source module>",
 | 
					 | 
				
			||||||
                ...
 | 
					 | 
				
			||||||
            },
 | 
					 | 
				
			||||||
            deprecated_aliases={
 | 
					 | 
				
			||||||
                ...,
 | 
					 | 
				
			||||||
            }
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    This function defines ``__getattr__`` and ``__dir__`` per :pep:`562`.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    if aliases is None:
 | 
					 | 
				
			||||||
        aliases = {}
 | 
					 | 
				
			||||||
    if deprecated_aliases is None:
 | 
					 | 
				
			||||||
        deprecated_aliases = {}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    namespace_set = set(namespace)
 | 
					 | 
				
			||||||
    aliases_set = set(aliases)
 | 
					 | 
				
			||||||
    deprecated_aliases_set = set(deprecated_aliases)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    assert not namespace_set & aliases_set, "namespace conflict"
 | 
					 | 
				
			||||||
    assert not namespace_set & deprecated_aliases_set, "namespace conflict"
 | 
					 | 
				
			||||||
    assert not aliases_set & deprecated_aliases_set, "namespace conflict"
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    package = namespace["__name__"]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __getattr__(name: str) -> Any:
 | 
					 | 
				
			||||||
        assert aliases is not None  # mypy cannot figure this out
 | 
					 | 
				
			||||||
        try:
 | 
					 | 
				
			||||||
            source = aliases[name]
 | 
					 | 
				
			||||||
        except KeyError:
 | 
					 | 
				
			||||||
            pass
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            return import_name(name, source, namespace)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        assert deprecated_aliases is not None  # mypy cannot figure this out
 | 
					 | 
				
			||||||
        try:
 | 
					 | 
				
			||||||
            source = deprecated_aliases[name]
 | 
					 | 
				
			||||||
        except KeyError:
 | 
					 | 
				
			||||||
            pass
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            warnings.warn(
 | 
					 | 
				
			||||||
                f"{package}.{name} is deprecated",
 | 
					 | 
				
			||||||
                DeprecationWarning,
 | 
					 | 
				
			||||||
                stacklevel=2,
 | 
					 | 
				
			||||||
            )
 | 
					 | 
				
			||||||
            return import_name(name, source, namespace)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        raise AttributeError(f"module {package!r} has no attribute {name!r}")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    namespace["__getattr__"] = __getattr__
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __dir__() -> Iterable[str]:
 | 
					 | 
				
			||||||
        return sorted(namespace_set | aliases_set | deprecated_aliases_set)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    namespace["__dir__"] = __dir__
 | 
					 | 
				
			||||||
@@ -1,190 +0,0 @@
 | 
				
			|||||||
from __future__ import annotations
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
import functools
 | 
					 | 
				
			||||||
import hmac
 | 
					 | 
				
			||||||
import http
 | 
					 | 
				
			||||||
from typing import Any, Awaitable, Callable, Iterable, Tuple, cast
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
from ..datastructures import Headers
 | 
					 | 
				
			||||||
from ..exceptions import InvalidHeader
 | 
					 | 
				
			||||||
from ..headers import build_www_authenticate_basic, parse_authorization_basic
 | 
					 | 
				
			||||||
from .server import HTTPResponse, WebSocketServerProtocol
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
__all__ = ["BasicAuthWebSocketServerProtocol", "basic_auth_protocol_factory"]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
# Change to tuple[str, str] when dropping Python < 3.9.
 | 
					 | 
				
			||||||
Credentials = Tuple[str, str]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def is_credentials(value: Any) -> bool:
 | 
					 | 
				
			||||||
    try:
 | 
					 | 
				
			||||||
        username, password = value
 | 
					 | 
				
			||||||
    except (TypeError, ValueError):
 | 
					 | 
				
			||||||
        return False
 | 
					 | 
				
			||||||
    else:
 | 
					 | 
				
			||||||
        return isinstance(username, str) and isinstance(password, str)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class BasicAuthWebSocketServerProtocol(WebSocketServerProtocol):
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    WebSocket server protocol that enforces HTTP Basic Auth.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    realm: str = ""
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Scope of protection.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    If provided, it should contain only ASCII characters because the
 | 
					 | 
				
			||||||
    encoding of non-ASCII characters is undefined.
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    username: str | None = None
 | 
					 | 
				
			||||||
    """Username of the authenticated user."""
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __init__(
 | 
					 | 
				
			||||||
        self,
 | 
					 | 
				
			||||||
        *args: Any,
 | 
					 | 
				
			||||||
        realm: str | None = None,
 | 
					 | 
				
			||||||
        check_credentials: Callable[[str, str], Awaitable[bool]] | None = None,
 | 
					 | 
				
			||||||
        **kwargs: Any,
 | 
					 | 
				
			||||||
    ) -> None:
 | 
					 | 
				
			||||||
        if realm is not None:
 | 
					 | 
				
			||||||
            self.realm = realm  # shadow class attribute
 | 
					 | 
				
			||||||
        self._check_credentials = check_credentials
 | 
					 | 
				
			||||||
        super().__init__(*args, **kwargs)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    async def check_credentials(self, username: str, password: str) -> bool:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Check whether credentials are authorized.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        This coroutine may be overridden in a subclass, for example to
 | 
					 | 
				
			||||||
        authenticate against a database or an external service.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Args:
 | 
					 | 
				
			||||||
            username: HTTP Basic Auth username.
 | 
					 | 
				
			||||||
            password: HTTP Basic Auth password.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Returns:
 | 
					 | 
				
			||||||
            :obj:`True` if the handshake should continue;
 | 
					 | 
				
			||||||
            :obj:`False` if it should fail with an HTTP 401 error.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        if self._check_credentials is not None:
 | 
					 | 
				
			||||||
            return await self._check_credentials(username, password)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        return False
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    async def process_request(
 | 
					 | 
				
			||||||
        self,
 | 
					 | 
				
			||||||
        path: str,
 | 
					 | 
				
			||||||
        request_headers: Headers,
 | 
					 | 
				
			||||||
    ) -> HTTPResponse | None:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Check HTTP Basic Auth and return an HTTP 401 response if needed.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        try:
 | 
					 | 
				
			||||||
            authorization = request_headers["Authorization"]
 | 
					 | 
				
			||||||
        except KeyError:
 | 
					 | 
				
			||||||
            return (
 | 
					 | 
				
			||||||
                http.HTTPStatus.UNAUTHORIZED,
 | 
					 | 
				
			||||||
                [("WWW-Authenticate", build_www_authenticate_basic(self.realm))],
 | 
					 | 
				
			||||||
                b"Missing credentials\n",
 | 
					 | 
				
			||||||
            )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        try:
 | 
					 | 
				
			||||||
            username, password = parse_authorization_basic(authorization)
 | 
					 | 
				
			||||||
        except InvalidHeader:
 | 
					 | 
				
			||||||
            return (
 | 
					 | 
				
			||||||
                http.HTTPStatus.UNAUTHORIZED,
 | 
					 | 
				
			||||||
                [("WWW-Authenticate", build_www_authenticate_basic(self.realm))],
 | 
					 | 
				
			||||||
                b"Unsupported credentials\n",
 | 
					 | 
				
			||||||
            )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if not await self.check_credentials(username, password):
 | 
					 | 
				
			||||||
            return (
 | 
					 | 
				
			||||||
                http.HTTPStatus.UNAUTHORIZED,
 | 
					 | 
				
			||||||
                [("WWW-Authenticate", build_www_authenticate_basic(self.realm))],
 | 
					 | 
				
			||||||
                b"Invalid credentials\n",
 | 
					 | 
				
			||||||
            )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        self.username = username
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        return await super().process_request(path, request_headers)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def basic_auth_protocol_factory(
 | 
					 | 
				
			||||||
    realm: str | None = None,
 | 
					 | 
				
			||||||
    credentials: Credentials | Iterable[Credentials] | None = None,
 | 
					 | 
				
			||||||
    check_credentials: Callable[[str, str], Awaitable[bool]] | None = None,
 | 
					 | 
				
			||||||
    create_protocol: Callable[..., BasicAuthWebSocketServerProtocol] | None = None,
 | 
					 | 
				
			||||||
) -> Callable[..., BasicAuthWebSocketServerProtocol]:
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Protocol factory that enforces HTTP Basic Auth.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    :func:`basic_auth_protocol_factory` is designed to integrate with
 | 
					 | 
				
			||||||
    :func:`~websockets.legacy.server.serve` like this::
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        serve(
 | 
					 | 
				
			||||||
            ...,
 | 
					 | 
				
			||||||
            create_protocol=basic_auth_protocol_factory(
 | 
					 | 
				
			||||||
                realm="my dev server",
 | 
					 | 
				
			||||||
                credentials=("hello", "iloveyou"),
 | 
					 | 
				
			||||||
            )
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Args:
 | 
					 | 
				
			||||||
        realm: Scope of protection. It should contain only ASCII characters
 | 
					 | 
				
			||||||
            because the encoding of non-ASCII characters is undefined.
 | 
					 | 
				
			||||||
            Refer to section 2.2 of :rfc:`7235` for details.
 | 
					 | 
				
			||||||
        credentials: Hard coded authorized credentials. It can be a
 | 
					 | 
				
			||||||
            ``(username, password)`` pair or a list of such pairs.
 | 
					 | 
				
			||||||
        check_credentials: Coroutine that verifies credentials.
 | 
					 | 
				
			||||||
            It receives ``username`` and ``password`` arguments
 | 
					 | 
				
			||||||
            and returns a :class:`bool`. One of ``credentials`` or
 | 
					 | 
				
			||||||
            ``check_credentials`` must be provided but not both.
 | 
					 | 
				
			||||||
        create_protocol: Factory that creates the protocol. By default, this
 | 
					 | 
				
			||||||
            is :class:`BasicAuthWebSocketServerProtocol`. It can be replaced
 | 
					 | 
				
			||||||
            by a subclass.
 | 
					 | 
				
			||||||
    Raises:
 | 
					 | 
				
			||||||
        TypeError: If the ``credentials`` or ``check_credentials`` argument is
 | 
					 | 
				
			||||||
            wrong.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    if (credentials is None) == (check_credentials is None):
 | 
					 | 
				
			||||||
        raise TypeError("provide either credentials or check_credentials")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    if credentials is not None:
 | 
					 | 
				
			||||||
        if is_credentials(credentials):
 | 
					 | 
				
			||||||
            credentials_list = [cast(Credentials, credentials)]
 | 
					 | 
				
			||||||
        elif isinstance(credentials, Iterable):
 | 
					 | 
				
			||||||
            credentials_list = list(cast(Iterable[Credentials], credentials))
 | 
					 | 
				
			||||||
            if not all(is_credentials(item) for item in credentials_list):
 | 
					 | 
				
			||||||
                raise TypeError(f"invalid credentials argument: {credentials}")
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            raise TypeError(f"invalid credentials argument: {credentials}")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        credentials_dict = dict(credentials_list)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        async def check_credentials(username: str, password: str) -> bool:
 | 
					 | 
				
			||||||
            try:
 | 
					 | 
				
			||||||
                expected_password = credentials_dict[username]
 | 
					 | 
				
			||||||
            except KeyError:
 | 
					 | 
				
			||||||
                return False
 | 
					 | 
				
			||||||
            return hmac.compare_digest(expected_password, password)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    if create_protocol is None:
 | 
					 | 
				
			||||||
        create_protocol = BasicAuthWebSocketServerProtocol
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    # Help mypy and avoid this error: "type[BasicAuthWebSocketServerProtocol] |
 | 
					 | 
				
			||||||
    # Callable[..., BasicAuthWebSocketServerProtocol]" not callable  [misc]
 | 
					 | 
				
			||||||
    create_protocol = cast(
 | 
					 | 
				
			||||||
        Callable[..., BasicAuthWebSocketServerProtocol], create_protocol
 | 
					 | 
				
			||||||
    )
 | 
					 | 
				
			||||||
    return functools.partial(
 | 
					 | 
				
			||||||
        create_protocol,
 | 
					 | 
				
			||||||
        realm=realm,
 | 
					 | 
				
			||||||
        check_credentials=check_credentials,
 | 
					 | 
				
			||||||
    )
 | 
					 | 
				
			||||||
@@ -1,707 +0,0 @@
 | 
				
			|||||||
from __future__ import annotations
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
import asyncio
 | 
					 | 
				
			||||||
import functools
 | 
					 | 
				
			||||||
import logging
 | 
					 | 
				
			||||||
import os
 | 
					 | 
				
			||||||
import random
 | 
					 | 
				
			||||||
import urllib.parse
 | 
					 | 
				
			||||||
import warnings
 | 
					 | 
				
			||||||
from types import TracebackType
 | 
					 | 
				
			||||||
from typing import (
 | 
					 | 
				
			||||||
    Any,
 | 
					 | 
				
			||||||
    AsyncIterator,
 | 
					 | 
				
			||||||
    Callable,
 | 
					 | 
				
			||||||
    Generator,
 | 
					 | 
				
			||||||
    Sequence,
 | 
					 | 
				
			||||||
    cast,
 | 
					 | 
				
			||||||
)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
from ..asyncio.compatibility import asyncio_timeout
 | 
					 | 
				
			||||||
from ..datastructures import Headers, HeadersLike
 | 
					 | 
				
			||||||
from ..exceptions import (
 | 
					 | 
				
			||||||
    InvalidHeader,
 | 
					 | 
				
			||||||
    InvalidHeaderValue,
 | 
					 | 
				
			||||||
    NegotiationError,
 | 
					 | 
				
			||||||
    SecurityError,
 | 
					 | 
				
			||||||
)
 | 
					 | 
				
			||||||
from ..extensions import ClientExtensionFactory, Extension
 | 
					 | 
				
			||||||
from ..extensions.permessage_deflate import enable_client_permessage_deflate
 | 
					 | 
				
			||||||
from ..headers import (
 | 
					 | 
				
			||||||
    build_authorization_basic,
 | 
					 | 
				
			||||||
    build_extension,
 | 
					 | 
				
			||||||
    build_host,
 | 
					 | 
				
			||||||
    build_subprotocol,
 | 
					 | 
				
			||||||
    parse_extension,
 | 
					 | 
				
			||||||
    parse_subprotocol,
 | 
					 | 
				
			||||||
    validate_subprotocols,
 | 
					 | 
				
			||||||
)
 | 
					 | 
				
			||||||
from ..http11 import USER_AGENT
 | 
					 | 
				
			||||||
from ..typing import ExtensionHeader, LoggerLike, Origin, Subprotocol
 | 
					 | 
				
			||||||
from ..uri import WebSocketURI, parse_uri
 | 
					 | 
				
			||||||
from .exceptions import InvalidMessage, InvalidStatusCode, RedirectHandshake
 | 
					 | 
				
			||||||
from .handshake import build_request, check_response
 | 
					 | 
				
			||||||
from .http import read_response
 | 
					 | 
				
			||||||
from .protocol import WebSocketCommonProtocol
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
__all__ = ["connect", "unix_connect", "WebSocketClientProtocol"]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class WebSocketClientProtocol(WebSocketCommonProtocol):
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    WebSocket client connection.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    :class:`WebSocketClientProtocol` provides :meth:`recv` and :meth:`send`
 | 
					 | 
				
			||||||
    coroutines for receiving and sending messages.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    It supports asynchronous iteration to receive messages::
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        async for message in websocket:
 | 
					 | 
				
			||||||
            await process(message)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    The iterator exits normally when the connection is closed with close code
 | 
					 | 
				
			||||||
    1000 (OK) or 1001 (going away) or without a close code. It raises
 | 
					 | 
				
			||||||
    a :exc:`~websockets.exceptions.ConnectionClosedError` when the connection
 | 
					 | 
				
			||||||
    is closed with any other code.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    See :func:`connect` for the documentation of ``logger``, ``origin``,
 | 
					 | 
				
			||||||
    ``extensions``, ``subprotocols``, ``extra_headers``, and
 | 
					 | 
				
			||||||
    ``user_agent_header``.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    See :class:`~websockets.legacy.protocol.WebSocketCommonProtocol` for the
 | 
					 | 
				
			||||||
    documentation of ``ping_interval``, ``ping_timeout``, ``close_timeout``,
 | 
					 | 
				
			||||||
    ``max_size``, ``max_queue``, ``read_limit``, and ``write_limit``.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    is_client = True
 | 
					 | 
				
			||||||
    side = "client"
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __init__(
 | 
					 | 
				
			||||||
        self,
 | 
					 | 
				
			||||||
        *,
 | 
					 | 
				
			||||||
        logger: LoggerLike | None = None,
 | 
					 | 
				
			||||||
        origin: Origin | None = None,
 | 
					 | 
				
			||||||
        extensions: Sequence[ClientExtensionFactory] | None = None,
 | 
					 | 
				
			||||||
        subprotocols: Sequence[Subprotocol] | None = None,
 | 
					 | 
				
			||||||
        extra_headers: HeadersLike | None = None,
 | 
					 | 
				
			||||||
        user_agent_header: str | None = USER_AGENT,
 | 
					 | 
				
			||||||
        **kwargs: Any,
 | 
					 | 
				
			||||||
    ) -> None:
 | 
					 | 
				
			||||||
        if logger is None:
 | 
					 | 
				
			||||||
            logger = logging.getLogger("websockets.client")
 | 
					 | 
				
			||||||
        super().__init__(logger=logger, **kwargs)
 | 
					 | 
				
			||||||
        self.origin = origin
 | 
					 | 
				
			||||||
        self.available_extensions = extensions
 | 
					 | 
				
			||||||
        self.available_subprotocols = subprotocols
 | 
					 | 
				
			||||||
        self.extra_headers = extra_headers
 | 
					 | 
				
			||||||
        self.user_agent_header = user_agent_header
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def write_http_request(self, path: str, headers: Headers) -> None:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Write request line and headers to the HTTP request.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        self.path = path
 | 
					 | 
				
			||||||
        self.request_headers = headers
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if self.debug:
 | 
					 | 
				
			||||||
            self.logger.debug("> GET %s HTTP/1.1", path)
 | 
					 | 
				
			||||||
            for key, value in headers.raw_items():
 | 
					 | 
				
			||||||
                self.logger.debug("> %s: %s", key, value)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # Since the path and headers only contain ASCII characters,
 | 
					 | 
				
			||||||
        # we can keep this simple.
 | 
					 | 
				
			||||||
        request = f"GET {path} HTTP/1.1\r\n"
 | 
					 | 
				
			||||||
        request += str(headers)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        self.transport.write(request.encode())
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    async def read_http_response(self) -> tuple[int, Headers]:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Read status line and headers from the HTTP response.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        If the response contains a body, it may be read from ``self.reader``
 | 
					 | 
				
			||||||
        after this coroutine returns.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Raises:
 | 
					 | 
				
			||||||
            InvalidMessage: If the HTTP message is malformed or isn't an
 | 
					 | 
				
			||||||
                HTTP/1.1 GET response.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        try:
 | 
					 | 
				
			||||||
            status_code, reason, headers = await read_response(self.reader)
 | 
					 | 
				
			||||||
        except Exception as exc:
 | 
					 | 
				
			||||||
            raise InvalidMessage("did not receive a valid HTTP response") from exc
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if self.debug:
 | 
					 | 
				
			||||||
            self.logger.debug("< HTTP/1.1 %d %s", status_code, reason)
 | 
					 | 
				
			||||||
            for key, value in headers.raw_items():
 | 
					 | 
				
			||||||
                self.logger.debug("< %s: %s", key, value)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        self.response_headers = headers
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        return status_code, self.response_headers
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    @staticmethod
 | 
					 | 
				
			||||||
    def process_extensions(
 | 
					 | 
				
			||||||
        headers: Headers,
 | 
					 | 
				
			||||||
        available_extensions: Sequence[ClientExtensionFactory] | None,
 | 
					 | 
				
			||||||
    ) -> list[Extension]:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Handle the Sec-WebSocket-Extensions HTTP response header.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Check that each extension is supported, as well as its parameters.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Return the list of accepted extensions.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Raise :exc:`~websockets.exceptions.InvalidHandshake` to abort the
 | 
					 | 
				
			||||||
        connection.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        :rfc:`6455` leaves the rules up to the specification of each
 | 
					 | 
				
			||||||
        :extension.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        To provide this level of flexibility, for each extension accepted by
 | 
					 | 
				
			||||||
        the server, we check for a match with each extension available in the
 | 
					 | 
				
			||||||
        client configuration. If no match is found, an exception is raised.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        If several variants of the same extension are accepted by the server,
 | 
					 | 
				
			||||||
        it may be configured several times, which won't make sense in general.
 | 
					 | 
				
			||||||
        Extensions must implement their own requirements. For this purpose,
 | 
					 | 
				
			||||||
        the list of previously accepted extensions is provided.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Other requirements, for example related to mandatory extensions or the
 | 
					 | 
				
			||||||
        order of extensions, may be implemented by overriding this method.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        accepted_extensions: list[Extension] = []
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        header_values = headers.get_all("Sec-WebSocket-Extensions")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if header_values:
 | 
					 | 
				
			||||||
            if available_extensions is None:
 | 
					 | 
				
			||||||
                raise NegotiationError("no extensions supported")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            parsed_header_values: list[ExtensionHeader] = sum(
 | 
					 | 
				
			||||||
                [parse_extension(header_value) for header_value in header_values], []
 | 
					 | 
				
			||||||
            )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            for name, response_params in parsed_header_values:
 | 
					 | 
				
			||||||
                for extension_factory in available_extensions:
 | 
					 | 
				
			||||||
                    # Skip non-matching extensions based on their name.
 | 
					 | 
				
			||||||
                    if extension_factory.name != name:
 | 
					 | 
				
			||||||
                        continue
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                    # Skip non-matching extensions based on their params.
 | 
					 | 
				
			||||||
                    try:
 | 
					 | 
				
			||||||
                        extension = extension_factory.process_response_params(
 | 
					 | 
				
			||||||
                            response_params, accepted_extensions
 | 
					 | 
				
			||||||
                        )
 | 
					 | 
				
			||||||
                    except NegotiationError:
 | 
					 | 
				
			||||||
                        continue
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                    # Add matching extension to the final list.
 | 
					 | 
				
			||||||
                    accepted_extensions.append(extension)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                    # Break out of the loop once we have a match.
 | 
					 | 
				
			||||||
                    break
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                # If we didn't break from the loop, no extension in our list
 | 
					 | 
				
			||||||
                # matched what the server sent. Fail the connection.
 | 
					 | 
				
			||||||
                else:
 | 
					 | 
				
			||||||
                    raise NegotiationError(
 | 
					 | 
				
			||||||
                        f"Unsupported extension: "
 | 
					 | 
				
			||||||
                        f"name = {name}, params = {response_params}"
 | 
					 | 
				
			||||||
                    )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        return accepted_extensions
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    @staticmethod
 | 
					 | 
				
			||||||
    def process_subprotocol(
 | 
					 | 
				
			||||||
        headers: Headers, available_subprotocols: Sequence[Subprotocol] | None
 | 
					 | 
				
			||||||
    ) -> Subprotocol | None:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Handle the Sec-WebSocket-Protocol HTTP response header.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Check that it contains exactly one supported subprotocol.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Return the selected subprotocol.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        subprotocol: Subprotocol | None = None
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        header_values = headers.get_all("Sec-WebSocket-Protocol")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if header_values:
 | 
					 | 
				
			||||||
            if available_subprotocols is None:
 | 
					 | 
				
			||||||
                raise NegotiationError("no subprotocols supported")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            parsed_header_values: Sequence[Subprotocol] = sum(
 | 
					 | 
				
			||||||
                [parse_subprotocol(header_value) for header_value in header_values], []
 | 
					 | 
				
			||||||
            )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            if len(parsed_header_values) > 1:
 | 
					 | 
				
			||||||
                raise InvalidHeaderValue(
 | 
					 | 
				
			||||||
                    "Sec-WebSocket-Protocol",
 | 
					 | 
				
			||||||
                    f"multiple values: {', '.join(parsed_header_values)}",
 | 
					 | 
				
			||||||
                )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            subprotocol = parsed_header_values[0]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            if subprotocol not in available_subprotocols:
 | 
					 | 
				
			||||||
                raise NegotiationError(f"unsupported subprotocol: {subprotocol}")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        return subprotocol
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    async def handshake(
 | 
					 | 
				
			||||||
        self,
 | 
					 | 
				
			||||||
        wsuri: WebSocketURI,
 | 
					 | 
				
			||||||
        origin: Origin | None = None,
 | 
					 | 
				
			||||||
        available_extensions: Sequence[ClientExtensionFactory] | None = None,
 | 
					 | 
				
			||||||
        available_subprotocols: Sequence[Subprotocol] | None = None,
 | 
					 | 
				
			||||||
        extra_headers: HeadersLike | None = None,
 | 
					 | 
				
			||||||
    ) -> None:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Perform the client side of the opening handshake.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Args:
 | 
					 | 
				
			||||||
            wsuri: URI of the WebSocket server.
 | 
					 | 
				
			||||||
            origin: Value of the ``Origin`` header.
 | 
					 | 
				
			||||||
            extensions: List of supported extensions, in order in which they
 | 
					 | 
				
			||||||
                should be negotiated and run.
 | 
					 | 
				
			||||||
            subprotocols: List of supported subprotocols, in order of decreasing
 | 
					 | 
				
			||||||
                preference.
 | 
					 | 
				
			||||||
            extra_headers: Arbitrary HTTP headers to add to the handshake request.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Raises:
 | 
					 | 
				
			||||||
            InvalidHandshake: If the handshake fails.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        request_headers = Headers()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        request_headers["Host"] = build_host(wsuri.host, wsuri.port, wsuri.secure)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if wsuri.user_info:
 | 
					 | 
				
			||||||
            request_headers["Authorization"] = build_authorization_basic(
 | 
					 | 
				
			||||||
                *wsuri.user_info
 | 
					 | 
				
			||||||
            )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if origin is not None:
 | 
					 | 
				
			||||||
            request_headers["Origin"] = origin
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        key = build_request(request_headers)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if available_extensions is not None:
 | 
					 | 
				
			||||||
            extensions_header = build_extension(
 | 
					 | 
				
			||||||
                [
 | 
					 | 
				
			||||||
                    (extension_factory.name, extension_factory.get_request_params())
 | 
					 | 
				
			||||||
                    for extension_factory in available_extensions
 | 
					 | 
				
			||||||
                ]
 | 
					 | 
				
			||||||
            )
 | 
					 | 
				
			||||||
            request_headers["Sec-WebSocket-Extensions"] = extensions_header
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if available_subprotocols is not None:
 | 
					 | 
				
			||||||
            protocol_header = build_subprotocol(available_subprotocols)
 | 
					 | 
				
			||||||
            request_headers["Sec-WebSocket-Protocol"] = protocol_header
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if self.extra_headers is not None:
 | 
					 | 
				
			||||||
            request_headers.update(self.extra_headers)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if self.user_agent_header:
 | 
					 | 
				
			||||||
            request_headers.setdefault("User-Agent", self.user_agent_header)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        self.write_http_request(wsuri.resource_name, request_headers)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        status_code, response_headers = await self.read_http_response()
 | 
					 | 
				
			||||||
        if status_code in (301, 302, 303, 307, 308):
 | 
					 | 
				
			||||||
            if "Location" not in response_headers:
 | 
					 | 
				
			||||||
                raise InvalidHeader("Location")
 | 
					 | 
				
			||||||
            raise RedirectHandshake(response_headers["Location"])
 | 
					 | 
				
			||||||
        elif status_code != 101:
 | 
					 | 
				
			||||||
            raise InvalidStatusCode(status_code, response_headers)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        check_response(response_headers, key)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        self.extensions = self.process_extensions(
 | 
					 | 
				
			||||||
            response_headers, available_extensions
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        self.subprotocol = self.process_subprotocol(
 | 
					 | 
				
			||||||
            response_headers, available_subprotocols
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        self.connection_open()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class Connect:
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Connect to the WebSocket server at ``uri``.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Awaiting :func:`connect` yields a :class:`WebSocketClientProtocol` which
 | 
					 | 
				
			||||||
    can then be used to send and receive messages.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    :func:`connect` can be used as a asynchronous context manager::
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        async with connect(...) as websocket:
 | 
					 | 
				
			||||||
            ...
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    The connection is closed automatically when exiting the context.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    :func:`connect` can be used as an infinite asynchronous iterator to
 | 
					 | 
				
			||||||
    reconnect automatically on errors::
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        async for websocket in connect(...):
 | 
					 | 
				
			||||||
            try:
 | 
					 | 
				
			||||||
                ...
 | 
					 | 
				
			||||||
            except websockets.ConnectionClosed:
 | 
					 | 
				
			||||||
                continue
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    The connection is closed automatically after each iteration of the loop.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    If an error occurs while establishing the connection, :func:`connect`
 | 
					 | 
				
			||||||
    retries with exponential backoff. The backoff delay starts at three
 | 
					 | 
				
			||||||
    seconds and increases up to one minute.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    If an error occurs in the body of the loop, you can handle the exception
 | 
					 | 
				
			||||||
    and :func:`connect` will reconnect with the next iteration; or you can
 | 
					 | 
				
			||||||
    let the exception bubble up and break out of the loop. This lets you
 | 
					 | 
				
			||||||
    decide which errors trigger a reconnection and which errors are fatal.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Args:
 | 
					 | 
				
			||||||
        uri: URI of the WebSocket server.
 | 
					 | 
				
			||||||
        create_protocol: Factory for the :class:`asyncio.Protocol` managing
 | 
					 | 
				
			||||||
            the connection. It defaults to :class:`WebSocketClientProtocol`.
 | 
					 | 
				
			||||||
            Set it to a wrapper or a subclass to customize connection handling.
 | 
					 | 
				
			||||||
        logger: Logger for this client.
 | 
					 | 
				
			||||||
            It defaults to ``logging.getLogger("websockets.client")``.
 | 
					 | 
				
			||||||
            See the :doc:`logging guide <../../topics/logging>` for details.
 | 
					 | 
				
			||||||
        compression: The "permessage-deflate" extension is enabled by default.
 | 
					 | 
				
			||||||
            Set ``compression`` to :obj:`None` to disable it. See the
 | 
					 | 
				
			||||||
            :doc:`compression guide <../../topics/compression>` for details.
 | 
					 | 
				
			||||||
        origin: Value of the ``Origin`` header, for servers that require it.
 | 
					 | 
				
			||||||
        extensions: List of supported extensions, in order in which they
 | 
					 | 
				
			||||||
            should be negotiated and run.
 | 
					 | 
				
			||||||
        subprotocols: List of supported subprotocols, in order of decreasing
 | 
					 | 
				
			||||||
            preference.
 | 
					 | 
				
			||||||
        extra_headers: Arbitrary HTTP headers to add to the handshake request.
 | 
					 | 
				
			||||||
        user_agent_header: Value of  the ``User-Agent`` request header.
 | 
					 | 
				
			||||||
            It defaults to ``"Python/x.y.z websockets/X.Y"``.
 | 
					 | 
				
			||||||
            Setting it to :obj:`None` removes the header.
 | 
					 | 
				
			||||||
        open_timeout: Timeout for opening the connection in seconds.
 | 
					 | 
				
			||||||
            :obj:`None` disables the timeout.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    See :class:`~websockets.legacy.protocol.WebSocketCommonProtocol` for the
 | 
					 | 
				
			||||||
    documentation of ``ping_interval``, ``ping_timeout``, ``close_timeout``,
 | 
					 | 
				
			||||||
    ``max_size``, ``max_queue``, ``read_limit``, and ``write_limit``.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Any other keyword arguments are passed the event loop's
 | 
					 | 
				
			||||||
    :meth:`~asyncio.loop.create_connection` method.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    For example:
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    * You can set ``ssl`` to a :class:`~ssl.SSLContext` to enforce TLS
 | 
					 | 
				
			||||||
      settings. When connecting to a ``wss://`` URI, if ``ssl`` isn't
 | 
					 | 
				
			||||||
      provided, a TLS context is created
 | 
					 | 
				
			||||||
      with :func:`~ssl.create_default_context`.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    * You can set ``host`` and ``port`` to connect to a different host and
 | 
					 | 
				
			||||||
      port from those found in ``uri``. This only changes the destination of
 | 
					 | 
				
			||||||
      the TCP connection. The host name from ``uri`` is still used in the TLS
 | 
					 | 
				
			||||||
      handshake for secure connections and in the ``Host`` header.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Raises:
 | 
					 | 
				
			||||||
        InvalidURI: If ``uri`` isn't a valid WebSocket URI.
 | 
					 | 
				
			||||||
        OSError: If the TCP connection fails.
 | 
					 | 
				
			||||||
        InvalidHandshake: If the opening handshake fails.
 | 
					 | 
				
			||||||
        ~asyncio.TimeoutError: If the opening handshake times out.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    MAX_REDIRECTS_ALLOWED = int(os.environ.get("WEBSOCKETS_MAX_REDIRECTS", "10"))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __init__(
 | 
					 | 
				
			||||||
        self,
 | 
					 | 
				
			||||||
        uri: str,
 | 
					 | 
				
			||||||
        *,
 | 
					 | 
				
			||||||
        create_protocol: Callable[..., WebSocketClientProtocol] | None = None,
 | 
					 | 
				
			||||||
        logger: LoggerLike | None = None,
 | 
					 | 
				
			||||||
        compression: str | None = "deflate",
 | 
					 | 
				
			||||||
        origin: Origin | None = None,
 | 
					 | 
				
			||||||
        extensions: Sequence[ClientExtensionFactory] | None = None,
 | 
					 | 
				
			||||||
        subprotocols: Sequence[Subprotocol] | None = None,
 | 
					 | 
				
			||||||
        extra_headers: HeadersLike | None = None,
 | 
					 | 
				
			||||||
        user_agent_header: str | None = USER_AGENT,
 | 
					 | 
				
			||||||
        open_timeout: float | None = 10,
 | 
					 | 
				
			||||||
        ping_interval: float | None = 20,
 | 
					 | 
				
			||||||
        ping_timeout: float | None = 20,
 | 
					 | 
				
			||||||
        close_timeout: float | None = None,
 | 
					 | 
				
			||||||
        max_size: int | None = 2**20,
 | 
					 | 
				
			||||||
        max_queue: int | None = 2**5,
 | 
					 | 
				
			||||||
        read_limit: int = 2**16,
 | 
					 | 
				
			||||||
        write_limit: int = 2**16,
 | 
					 | 
				
			||||||
        **kwargs: Any,
 | 
					 | 
				
			||||||
    ) -> None:
 | 
					 | 
				
			||||||
        # Backwards compatibility: close_timeout used to be called timeout.
 | 
					 | 
				
			||||||
        timeout: float | None = kwargs.pop("timeout", None)
 | 
					 | 
				
			||||||
        if timeout is None:
 | 
					 | 
				
			||||||
            timeout = 10
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            warnings.warn("rename timeout to close_timeout", DeprecationWarning)
 | 
					 | 
				
			||||||
        # If both are specified, timeout is ignored.
 | 
					 | 
				
			||||||
        if close_timeout is None:
 | 
					 | 
				
			||||||
            close_timeout = timeout
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # Backwards compatibility: create_protocol used to be called klass.
 | 
					 | 
				
			||||||
        klass: type[WebSocketClientProtocol] | None = kwargs.pop("klass", None)
 | 
					 | 
				
			||||||
        if klass is None:
 | 
					 | 
				
			||||||
            klass = WebSocketClientProtocol
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            warnings.warn("rename klass to create_protocol", DeprecationWarning)
 | 
					 | 
				
			||||||
        # If both are specified, klass is ignored.
 | 
					 | 
				
			||||||
        if create_protocol is None:
 | 
					 | 
				
			||||||
            create_protocol = klass
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # Backwards compatibility: recv() used to return None on closed connections
 | 
					 | 
				
			||||||
        legacy_recv: bool = kwargs.pop("legacy_recv", False)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # Backwards compatibility: the loop parameter used to be supported.
 | 
					 | 
				
			||||||
        _loop: asyncio.AbstractEventLoop | None = kwargs.pop("loop", None)
 | 
					 | 
				
			||||||
        if _loop is None:
 | 
					 | 
				
			||||||
            loop = asyncio.get_event_loop()
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            loop = _loop
 | 
					 | 
				
			||||||
            warnings.warn("remove loop argument", DeprecationWarning)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        wsuri = parse_uri(uri)
 | 
					 | 
				
			||||||
        if wsuri.secure:
 | 
					 | 
				
			||||||
            kwargs.setdefault("ssl", True)
 | 
					 | 
				
			||||||
        elif kwargs.get("ssl") is not None:
 | 
					 | 
				
			||||||
            raise ValueError(
 | 
					 | 
				
			||||||
                "connect() received a ssl argument for a ws:// URI, "
 | 
					 | 
				
			||||||
                "use a wss:// URI to enable TLS"
 | 
					 | 
				
			||||||
            )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if compression == "deflate":
 | 
					 | 
				
			||||||
            extensions = enable_client_permessage_deflate(extensions)
 | 
					 | 
				
			||||||
        elif compression is not None:
 | 
					 | 
				
			||||||
            raise ValueError(f"unsupported compression: {compression}")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if subprotocols is not None:
 | 
					 | 
				
			||||||
            validate_subprotocols(subprotocols)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # Help mypy and avoid this error: "type[WebSocketClientProtocol] |
 | 
					 | 
				
			||||||
        # Callable[..., WebSocketClientProtocol]" not callable  [misc]
 | 
					 | 
				
			||||||
        create_protocol = cast(Callable[..., WebSocketClientProtocol], create_protocol)
 | 
					 | 
				
			||||||
        factory = functools.partial(
 | 
					 | 
				
			||||||
            create_protocol,
 | 
					 | 
				
			||||||
            logger=logger,
 | 
					 | 
				
			||||||
            origin=origin,
 | 
					 | 
				
			||||||
            extensions=extensions,
 | 
					 | 
				
			||||||
            subprotocols=subprotocols,
 | 
					 | 
				
			||||||
            extra_headers=extra_headers,
 | 
					 | 
				
			||||||
            user_agent_header=user_agent_header,
 | 
					 | 
				
			||||||
            ping_interval=ping_interval,
 | 
					 | 
				
			||||||
            ping_timeout=ping_timeout,
 | 
					 | 
				
			||||||
            close_timeout=close_timeout,
 | 
					 | 
				
			||||||
            max_size=max_size,
 | 
					 | 
				
			||||||
            max_queue=max_queue,
 | 
					 | 
				
			||||||
            read_limit=read_limit,
 | 
					 | 
				
			||||||
            write_limit=write_limit,
 | 
					 | 
				
			||||||
            host=wsuri.host,
 | 
					 | 
				
			||||||
            port=wsuri.port,
 | 
					 | 
				
			||||||
            secure=wsuri.secure,
 | 
					 | 
				
			||||||
            legacy_recv=legacy_recv,
 | 
					 | 
				
			||||||
            loop=_loop,
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if kwargs.pop("unix", False):
 | 
					 | 
				
			||||||
            path: str | None = kwargs.pop("path", None)
 | 
					 | 
				
			||||||
            create_connection = functools.partial(
 | 
					 | 
				
			||||||
                loop.create_unix_connection, factory, path, **kwargs
 | 
					 | 
				
			||||||
            )
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            host: str | None
 | 
					 | 
				
			||||||
            port: int | None
 | 
					 | 
				
			||||||
            if kwargs.get("sock") is None:
 | 
					 | 
				
			||||||
                host, port = wsuri.host, wsuri.port
 | 
					 | 
				
			||||||
            else:
 | 
					 | 
				
			||||||
                # If sock is given, host and port shouldn't be specified.
 | 
					 | 
				
			||||||
                host, port = None, None
 | 
					 | 
				
			||||||
                if kwargs.get("ssl"):
 | 
					 | 
				
			||||||
                    kwargs.setdefault("server_hostname", wsuri.host)
 | 
					 | 
				
			||||||
            # If host and port are given, override values from the URI.
 | 
					 | 
				
			||||||
            host = kwargs.pop("host", host)
 | 
					 | 
				
			||||||
            port = kwargs.pop("port", port)
 | 
					 | 
				
			||||||
            create_connection = functools.partial(
 | 
					 | 
				
			||||||
                loop.create_connection, factory, host, port, **kwargs
 | 
					 | 
				
			||||||
            )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        self.open_timeout = open_timeout
 | 
					 | 
				
			||||||
        if logger is None:
 | 
					 | 
				
			||||||
            logger = logging.getLogger("websockets.client")
 | 
					 | 
				
			||||||
        self.logger = logger
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # This is a coroutine function.
 | 
					 | 
				
			||||||
        self._create_connection = create_connection
 | 
					 | 
				
			||||||
        self._uri = uri
 | 
					 | 
				
			||||||
        self._wsuri = wsuri
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def handle_redirect(self, uri: str) -> None:
 | 
					 | 
				
			||||||
        # Update the state of this instance to connect to a new URI.
 | 
					 | 
				
			||||||
        old_uri = self._uri
 | 
					 | 
				
			||||||
        old_wsuri = self._wsuri
 | 
					 | 
				
			||||||
        new_uri = urllib.parse.urljoin(old_uri, uri)
 | 
					 | 
				
			||||||
        new_wsuri = parse_uri(new_uri)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # Forbid TLS downgrade.
 | 
					 | 
				
			||||||
        if old_wsuri.secure and not new_wsuri.secure:
 | 
					 | 
				
			||||||
            raise SecurityError("redirect from WSS to WS")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        same_origin = (
 | 
					 | 
				
			||||||
            old_wsuri.secure == new_wsuri.secure
 | 
					 | 
				
			||||||
            and old_wsuri.host == new_wsuri.host
 | 
					 | 
				
			||||||
            and old_wsuri.port == new_wsuri.port
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # Rewrite secure, host, and port for cross-origin redirects.
 | 
					 | 
				
			||||||
        # This preserves connection overrides with the host and port
 | 
					 | 
				
			||||||
        # arguments if the redirect points to the same host and port.
 | 
					 | 
				
			||||||
        if not same_origin:
 | 
					 | 
				
			||||||
            factory = self._create_connection.args[0]
 | 
					 | 
				
			||||||
            # Support TLS upgrade.
 | 
					 | 
				
			||||||
            if not old_wsuri.secure and new_wsuri.secure:
 | 
					 | 
				
			||||||
                factory.keywords["secure"] = True
 | 
					 | 
				
			||||||
                self._create_connection.keywords.setdefault("ssl", True)
 | 
					 | 
				
			||||||
            # Replace secure, host, and port arguments of the protocol factory.
 | 
					 | 
				
			||||||
            factory = functools.partial(
 | 
					 | 
				
			||||||
                factory.func,
 | 
					 | 
				
			||||||
                *factory.args,
 | 
					 | 
				
			||||||
                **dict(factory.keywords, host=new_wsuri.host, port=new_wsuri.port),
 | 
					 | 
				
			||||||
            )
 | 
					 | 
				
			||||||
            # Replace secure, host, and port arguments of create_connection.
 | 
					 | 
				
			||||||
            self._create_connection = functools.partial(
 | 
					 | 
				
			||||||
                self._create_connection.func,
 | 
					 | 
				
			||||||
                *(factory, new_wsuri.host, new_wsuri.port),
 | 
					 | 
				
			||||||
                **self._create_connection.keywords,
 | 
					 | 
				
			||||||
            )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # Set the new WebSocket URI. This suffices for same-origin redirects.
 | 
					 | 
				
			||||||
        self._uri = new_uri
 | 
					 | 
				
			||||||
        self._wsuri = new_wsuri
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    # async for ... in connect(...):
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    BACKOFF_INITIAL = float(os.environ.get("WEBSOCKETS_BACKOFF_INITIAL_DELAY", "5"))
 | 
					 | 
				
			||||||
    BACKOFF_MIN = float(os.environ.get("WEBSOCKETS_BACKOFF_MIN_DELAY", "3.1"))
 | 
					 | 
				
			||||||
    BACKOFF_MAX = float(os.environ.get("WEBSOCKETS_BACKOFF_MAX_DELAY", "90.0"))
 | 
					 | 
				
			||||||
    BACKOFF_FACTOR = float(os.environ.get("WEBSOCKETS_BACKOFF_FACTOR", "1.618"))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    async def __aiter__(self) -> AsyncIterator[WebSocketClientProtocol]:
 | 
					 | 
				
			||||||
        backoff_delay = self.BACKOFF_MIN / self.BACKOFF_FACTOR
 | 
					 | 
				
			||||||
        while True:
 | 
					 | 
				
			||||||
            try:
 | 
					 | 
				
			||||||
                async with self as protocol:
 | 
					 | 
				
			||||||
                    yield protocol
 | 
					 | 
				
			||||||
            except Exception:
 | 
					 | 
				
			||||||
                # Add a random initial delay between 0 and 5 seconds.
 | 
					 | 
				
			||||||
                # See 7.2.3. Recovering from Abnormal Closure in RFC 6455.
 | 
					 | 
				
			||||||
                if backoff_delay == self.BACKOFF_MIN:
 | 
					 | 
				
			||||||
                    initial_delay = random.random() * self.BACKOFF_INITIAL
 | 
					 | 
				
			||||||
                    self.logger.info(
 | 
					 | 
				
			||||||
                        "! connect failed; reconnecting in %.1f seconds",
 | 
					 | 
				
			||||||
                        initial_delay,
 | 
					 | 
				
			||||||
                        exc_info=True,
 | 
					 | 
				
			||||||
                    )
 | 
					 | 
				
			||||||
                    await asyncio.sleep(initial_delay)
 | 
					 | 
				
			||||||
                else:
 | 
					 | 
				
			||||||
                    self.logger.info(
 | 
					 | 
				
			||||||
                        "! connect failed again; retrying in %d seconds",
 | 
					 | 
				
			||||||
                        int(backoff_delay),
 | 
					 | 
				
			||||||
                        exc_info=True,
 | 
					 | 
				
			||||||
                    )
 | 
					 | 
				
			||||||
                    await asyncio.sleep(int(backoff_delay))
 | 
					 | 
				
			||||||
                # Increase delay with truncated exponential backoff.
 | 
					 | 
				
			||||||
                backoff_delay = backoff_delay * self.BACKOFF_FACTOR
 | 
					 | 
				
			||||||
                backoff_delay = min(backoff_delay, self.BACKOFF_MAX)
 | 
					 | 
				
			||||||
                continue
 | 
					 | 
				
			||||||
            else:
 | 
					 | 
				
			||||||
                # Connection succeeded - reset backoff delay
 | 
					 | 
				
			||||||
                backoff_delay = self.BACKOFF_MIN
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    # async with connect(...) as ...:
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    async def __aenter__(self) -> WebSocketClientProtocol:
 | 
					 | 
				
			||||||
        return await self
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    async def __aexit__(
 | 
					 | 
				
			||||||
        self,
 | 
					 | 
				
			||||||
        exc_type: type[BaseException] | None,
 | 
					 | 
				
			||||||
        exc_value: BaseException | None,
 | 
					 | 
				
			||||||
        traceback: TracebackType | None,
 | 
					 | 
				
			||||||
    ) -> None:
 | 
					 | 
				
			||||||
        await self.protocol.close()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    # ... = await connect(...)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __await__(self) -> Generator[Any, None, WebSocketClientProtocol]:
 | 
					 | 
				
			||||||
        # Create a suitable iterator by calling __await__ on a coroutine.
 | 
					 | 
				
			||||||
        return self.__await_impl__().__await__()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    async def __await_impl__(self) -> WebSocketClientProtocol:
 | 
					 | 
				
			||||||
        async with asyncio_timeout(self.open_timeout):
 | 
					 | 
				
			||||||
            for _redirects in range(self.MAX_REDIRECTS_ALLOWED):
 | 
					 | 
				
			||||||
                _transport, protocol = await self._create_connection()
 | 
					 | 
				
			||||||
                try:
 | 
					 | 
				
			||||||
                    await protocol.handshake(
 | 
					 | 
				
			||||||
                        self._wsuri,
 | 
					 | 
				
			||||||
                        origin=protocol.origin,
 | 
					 | 
				
			||||||
                        available_extensions=protocol.available_extensions,
 | 
					 | 
				
			||||||
                        available_subprotocols=protocol.available_subprotocols,
 | 
					 | 
				
			||||||
                        extra_headers=protocol.extra_headers,
 | 
					 | 
				
			||||||
                    )
 | 
					 | 
				
			||||||
                except RedirectHandshake as exc:
 | 
					 | 
				
			||||||
                    protocol.fail_connection()
 | 
					 | 
				
			||||||
                    await protocol.wait_closed()
 | 
					 | 
				
			||||||
                    self.handle_redirect(exc.uri)
 | 
					 | 
				
			||||||
                # Avoid leaking a connected socket when the handshake fails.
 | 
					 | 
				
			||||||
                except (Exception, asyncio.CancelledError):
 | 
					 | 
				
			||||||
                    protocol.fail_connection()
 | 
					 | 
				
			||||||
                    await protocol.wait_closed()
 | 
					 | 
				
			||||||
                    raise
 | 
					 | 
				
			||||||
                else:
 | 
					 | 
				
			||||||
                    self.protocol = protocol
 | 
					 | 
				
			||||||
                    return protocol
 | 
					 | 
				
			||||||
            else:
 | 
					 | 
				
			||||||
                raise SecurityError("too many redirects")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    # ... = yield from connect(...) - remove when dropping Python < 3.10
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    __iter__ = __await__
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
connect = Connect
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def unix_connect(
 | 
					 | 
				
			||||||
    path: str | None = None,
 | 
					 | 
				
			||||||
    uri: str = "ws://localhost/",
 | 
					 | 
				
			||||||
    **kwargs: Any,
 | 
					 | 
				
			||||||
) -> Connect:
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Similar to :func:`connect`, but for connecting to a Unix socket.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    This function builds upon the event loop's
 | 
					 | 
				
			||||||
    :meth:`~asyncio.loop.create_unix_connection` method.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    It is only available on Unix.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    It's mainly useful for debugging servers listening on Unix sockets.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Args:
 | 
					 | 
				
			||||||
        path: File system path to the Unix socket.
 | 
					 | 
				
			||||||
        uri: URI of the WebSocket server; the host is used in the TLS
 | 
					 | 
				
			||||||
            handshake for secure connections and in the ``Host`` header.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    return connect(uri=uri, path=path, unix=True, **kwargs)
 | 
					 | 
				
			||||||
@@ -1,78 +0,0 @@
 | 
				
			|||||||
import http
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
from .. import datastructures
 | 
					 | 
				
			||||||
from ..exceptions import (
 | 
					 | 
				
			||||||
    InvalidHandshake,
 | 
					 | 
				
			||||||
    ProtocolError as WebSocketProtocolError,  # noqa: F401
 | 
					 | 
				
			||||||
)
 | 
					 | 
				
			||||||
from ..typing import StatusLike
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class InvalidMessage(InvalidHandshake):
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Raised when a handshake request or response is malformed.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class InvalidStatusCode(InvalidHandshake):
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Raised when a handshake response status code is invalid.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __init__(self, status_code: int, headers: datastructures.Headers) -> None:
 | 
					 | 
				
			||||||
        self.status_code = status_code
 | 
					 | 
				
			||||||
        self.headers = headers
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __str__(self) -> str:
 | 
					 | 
				
			||||||
        return f"server rejected WebSocket connection: HTTP {self.status_code}"
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class AbortHandshake(InvalidHandshake):
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Raised to abort the handshake on purpose and return an HTTP response.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    This exception is an implementation detail.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    The public API is
 | 
					 | 
				
			||||||
    :meth:`~websockets.legacy.server.WebSocketServerProtocol.process_request`.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Attributes:
 | 
					 | 
				
			||||||
        status (~http.HTTPStatus): HTTP status code.
 | 
					 | 
				
			||||||
        headers (Headers): HTTP response headers.
 | 
					 | 
				
			||||||
        body (bytes): HTTP response body.
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __init__(
 | 
					 | 
				
			||||||
        self,
 | 
					 | 
				
			||||||
        status: StatusLike,
 | 
					 | 
				
			||||||
        headers: datastructures.HeadersLike,
 | 
					 | 
				
			||||||
        body: bytes = b"",
 | 
					 | 
				
			||||||
    ) -> None:
 | 
					 | 
				
			||||||
        # If a user passes an int instead of a HTTPStatus, fix it automatically.
 | 
					 | 
				
			||||||
        self.status = http.HTTPStatus(status)
 | 
					 | 
				
			||||||
        self.headers = datastructures.Headers(headers)
 | 
					 | 
				
			||||||
        self.body = body
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __str__(self) -> str:
 | 
					 | 
				
			||||||
        return (
 | 
					 | 
				
			||||||
            f"HTTP {self.status:d}, "
 | 
					 | 
				
			||||||
            f"{len(self.headers)} headers, "
 | 
					 | 
				
			||||||
            f"{len(self.body)} bytes"
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class RedirectHandshake(InvalidHandshake):
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Raised when a handshake gets redirected.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    This exception is an implementation detail.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __init__(self, uri: str) -> None:
 | 
					 | 
				
			||||||
        self.uri = uri
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __str__(self) -> str:
 | 
					 | 
				
			||||||
        return f"redirect to {self.uri}"
 | 
					 | 
				
			||||||
@@ -1,224 +0,0 @@
 | 
				
			|||||||
from __future__ import annotations
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
import struct
 | 
					 | 
				
			||||||
from typing import Any, Awaitable, Callable, NamedTuple, Sequence
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
from .. import extensions, frames
 | 
					 | 
				
			||||||
from ..exceptions import PayloadTooBig, ProtocolError
 | 
					 | 
				
			||||||
from ..frames import BytesLike
 | 
					 | 
				
			||||||
from ..typing import Data
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
try:
 | 
					 | 
				
			||||||
    from ..speedups import apply_mask
 | 
					 | 
				
			||||||
except ImportError:
 | 
					 | 
				
			||||||
    from ..utils import apply_mask
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class Frame(NamedTuple):
 | 
					 | 
				
			||||||
    fin: bool
 | 
					 | 
				
			||||||
    opcode: frames.Opcode
 | 
					 | 
				
			||||||
    data: bytes
 | 
					 | 
				
			||||||
    rsv1: bool = False
 | 
					 | 
				
			||||||
    rsv2: bool = False
 | 
					 | 
				
			||||||
    rsv3: bool = False
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    @property
 | 
					 | 
				
			||||||
    def new_frame(self) -> frames.Frame:
 | 
					 | 
				
			||||||
        return frames.Frame(
 | 
					 | 
				
			||||||
            self.opcode,
 | 
					 | 
				
			||||||
            self.data,
 | 
					 | 
				
			||||||
            self.fin,
 | 
					 | 
				
			||||||
            self.rsv1,
 | 
					 | 
				
			||||||
            self.rsv2,
 | 
					 | 
				
			||||||
            self.rsv3,
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __str__(self) -> str:
 | 
					 | 
				
			||||||
        return str(self.new_frame)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def check(self) -> None:
 | 
					 | 
				
			||||||
        return self.new_frame.check()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    @classmethod
 | 
					 | 
				
			||||||
    async def read(
 | 
					 | 
				
			||||||
        cls,
 | 
					 | 
				
			||||||
        reader: Callable[[int], Awaitable[bytes]],
 | 
					 | 
				
			||||||
        *,
 | 
					 | 
				
			||||||
        mask: bool,
 | 
					 | 
				
			||||||
        max_size: int | None = None,
 | 
					 | 
				
			||||||
        extensions: Sequence[extensions.Extension] | None = None,
 | 
					 | 
				
			||||||
    ) -> Frame:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Read a WebSocket frame.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Args:
 | 
					 | 
				
			||||||
            reader: Coroutine that reads exactly the requested number of
 | 
					 | 
				
			||||||
                bytes, unless the end of file is reached.
 | 
					 | 
				
			||||||
            mask: Whether the frame should be masked i.e. whether the read
 | 
					 | 
				
			||||||
                happens on the server side.
 | 
					 | 
				
			||||||
            max_size: Maximum payload size in bytes.
 | 
					 | 
				
			||||||
            extensions: List of extensions, applied in reverse order.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Raises:
 | 
					 | 
				
			||||||
            PayloadTooBig: If the frame exceeds ``max_size``.
 | 
					 | 
				
			||||||
            ProtocolError: If the frame contains incorrect values.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # Read the header.
 | 
					 | 
				
			||||||
        data = await reader(2)
 | 
					 | 
				
			||||||
        head1, head2 = struct.unpack("!BB", data)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # While not Pythonic, this is marginally faster than calling bool().
 | 
					 | 
				
			||||||
        fin = True if head1 & 0b10000000 else False
 | 
					 | 
				
			||||||
        rsv1 = True if head1 & 0b01000000 else False
 | 
					 | 
				
			||||||
        rsv2 = True if head1 & 0b00100000 else False
 | 
					 | 
				
			||||||
        rsv3 = True if head1 & 0b00010000 else False
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        try:
 | 
					 | 
				
			||||||
            opcode = frames.Opcode(head1 & 0b00001111)
 | 
					 | 
				
			||||||
        except ValueError as exc:
 | 
					 | 
				
			||||||
            raise ProtocolError("invalid opcode") from exc
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if (True if head2 & 0b10000000 else False) != mask:
 | 
					 | 
				
			||||||
            raise ProtocolError("incorrect masking")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        length = head2 & 0b01111111
 | 
					 | 
				
			||||||
        if length == 126:
 | 
					 | 
				
			||||||
            data = await reader(2)
 | 
					 | 
				
			||||||
            (length,) = struct.unpack("!H", data)
 | 
					 | 
				
			||||||
        elif length == 127:
 | 
					 | 
				
			||||||
            data = await reader(8)
 | 
					 | 
				
			||||||
            (length,) = struct.unpack("!Q", data)
 | 
					 | 
				
			||||||
        if max_size is not None and length > max_size:
 | 
					 | 
				
			||||||
            raise PayloadTooBig(f"over size limit ({length} > {max_size} bytes)")
 | 
					 | 
				
			||||||
        if mask:
 | 
					 | 
				
			||||||
            mask_bits = await reader(4)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # Read the data.
 | 
					 | 
				
			||||||
        data = await reader(length)
 | 
					 | 
				
			||||||
        if mask:
 | 
					 | 
				
			||||||
            data = apply_mask(data, mask_bits)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        new_frame = frames.Frame(opcode, data, fin, rsv1, rsv2, rsv3)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if extensions is None:
 | 
					 | 
				
			||||||
            extensions = []
 | 
					 | 
				
			||||||
        for extension in reversed(extensions):
 | 
					 | 
				
			||||||
            new_frame = extension.decode(new_frame, max_size=max_size)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        new_frame.check()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        return cls(
 | 
					 | 
				
			||||||
            new_frame.fin,
 | 
					 | 
				
			||||||
            new_frame.opcode,
 | 
					 | 
				
			||||||
            new_frame.data,
 | 
					 | 
				
			||||||
            new_frame.rsv1,
 | 
					 | 
				
			||||||
            new_frame.rsv2,
 | 
					 | 
				
			||||||
            new_frame.rsv3,
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def write(
 | 
					 | 
				
			||||||
        self,
 | 
					 | 
				
			||||||
        write: Callable[[bytes], Any],
 | 
					 | 
				
			||||||
        *,
 | 
					 | 
				
			||||||
        mask: bool,
 | 
					 | 
				
			||||||
        extensions: Sequence[extensions.Extension] | None = None,
 | 
					 | 
				
			||||||
    ) -> None:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Write a WebSocket frame.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Args:
 | 
					 | 
				
			||||||
            frame: Frame to write.
 | 
					 | 
				
			||||||
            write: Function that writes bytes.
 | 
					 | 
				
			||||||
            mask: Whether the frame should be masked i.e. whether the write
 | 
					 | 
				
			||||||
                happens on the client side.
 | 
					 | 
				
			||||||
            extensions: List of extensions, applied in order.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Raises:
 | 
					 | 
				
			||||||
            ProtocolError: If the frame contains incorrect values.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        # The frame is written in a single call to write in order to prevent
 | 
					 | 
				
			||||||
        # TCP fragmentation. See #68 for details. This also makes it safe to
 | 
					 | 
				
			||||||
        # send frames concurrently from multiple coroutines.
 | 
					 | 
				
			||||||
        write(self.new_frame.serialize(mask=mask, extensions=extensions))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def prepare_data(data: Data) -> tuple[int, bytes]:
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Convert a string or byte-like object to an opcode and a bytes-like object.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    This function is designed for data frames.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    If ``data`` is a :class:`str`, return ``OP_TEXT`` and a :class:`bytes`
 | 
					 | 
				
			||||||
    object encoding ``data`` in UTF-8.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    If ``data`` is a bytes-like object, return ``OP_BINARY`` and a bytes-like
 | 
					 | 
				
			||||||
    object.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Raises:
 | 
					 | 
				
			||||||
        TypeError: If ``data`` doesn't have a supported type.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    if isinstance(data, str):
 | 
					 | 
				
			||||||
        return frames.Opcode.TEXT, data.encode()
 | 
					 | 
				
			||||||
    elif isinstance(data, BytesLike):
 | 
					 | 
				
			||||||
        return frames.Opcode.BINARY, data
 | 
					 | 
				
			||||||
    else:
 | 
					 | 
				
			||||||
        raise TypeError("data must be str or bytes-like")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def prepare_ctrl(data: Data) -> bytes:
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Convert a string or byte-like object to bytes.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    This function is designed for ping and pong frames.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    If ``data`` is a :class:`str`, return a :class:`bytes` object encoding
 | 
					 | 
				
			||||||
    ``data`` in UTF-8.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    If ``data`` is a bytes-like object, return a :class:`bytes` object.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Raises:
 | 
					 | 
				
			||||||
        TypeError: If ``data`` doesn't have a supported type.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    if isinstance(data, str):
 | 
					 | 
				
			||||||
        return data.encode()
 | 
					 | 
				
			||||||
    elif isinstance(data, BytesLike):
 | 
					 | 
				
			||||||
        return bytes(data)
 | 
					 | 
				
			||||||
    else:
 | 
					 | 
				
			||||||
        raise TypeError("data must be str or bytes-like")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
# Backwards compatibility with previously documented public APIs
 | 
					 | 
				
			||||||
encode_data = prepare_ctrl
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
# Backwards compatibility with previously documented public APIs
 | 
					 | 
				
			||||||
from ..frames import Close  # noqa: E402 F401, I001
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def parse_close(data: bytes) -> tuple[int, str]:
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Parse the payload from a close frame.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Returns:
 | 
					 | 
				
			||||||
        Close code and reason.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Raises:
 | 
					 | 
				
			||||||
        ProtocolError: If data is ill-formed.
 | 
					 | 
				
			||||||
        UnicodeDecodeError: If the reason isn't valid UTF-8.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    close = Close.parse(data)
 | 
					 | 
				
			||||||
    return close.code, close.reason
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def serialize_close(code: int, reason: str) -> bytes:
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Serialize the payload for a close frame.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    return Close(code, reason).serialize()
 | 
					 | 
				
			||||||
@@ -1,158 +0,0 @@
 | 
				
			|||||||
from __future__ import annotations
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
import base64
 | 
					 | 
				
			||||||
import binascii
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
from ..datastructures import Headers, MultipleValuesError
 | 
					 | 
				
			||||||
from ..exceptions import InvalidHeader, InvalidHeaderValue, InvalidUpgrade
 | 
					 | 
				
			||||||
from ..headers import parse_connection, parse_upgrade
 | 
					 | 
				
			||||||
from ..typing import ConnectionOption, UpgradeProtocol
 | 
					 | 
				
			||||||
from ..utils import accept_key as accept, generate_key
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
__all__ = ["build_request", "check_request", "build_response", "check_response"]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def build_request(headers: Headers) -> str:
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Build a handshake request to send to the server.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Update request headers passed in argument.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Args:
 | 
					 | 
				
			||||||
        headers: Handshake request headers.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Returns:
 | 
					 | 
				
			||||||
        ``key`` that must be passed to :func:`check_response`.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    key = generate_key()
 | 
					 | 
				
			||||||
    headers["Upgrade"] = "websocket"
 | 
					 | 
				
			||||||
    headers["Connection"] = "Upgrade"
 | 
					 | 
				
			||||||
    headers["Sec-WebSocket-Key"] = key
 | 
					 | 
				
			||||||
    headers["Sec-WebSocket-Version"] = "13"
 | 
					 | 
				
			||||||
    return key
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def check_request(headers: Headers) -> str:
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Check a handshake request received from the client.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    This function doesn't verify that the request is an HTTP/1.1 or higher GET
 | 
					 | 
				
			||||||
    request and doesn't perform ``Host`` and ``Origin`` checks. These controls
 | 
					 | 
				
			||||||
    are usually performed earlier in the HTTP request handling code. They're
 | 
					 | 
				
			||||||
    the responsibility of the caller.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Args:
 | 
					 | 
				
			||||||
        headers: Handshake request headers.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Returns:
 | 
					 | 
				
			||||||
        ``key`` that must be passed to :func:`build_response`.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Raises:
 | 
					 | 
				
			||||||
        InvalidHandshake: If the handshake request is invalid.
 | 
					 | 
				
			||||||
            Then, the server must return a 400 Bad Request error.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    connection: list[ConnectionOption] = sum(
 | 
					 | 
				
			||||||
        [parse_connection(value) for value in headers.get_all("Connection")], []
 | 
					 | 
				
			||||||
    )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    if not any(value.lower() == "upgrade" for value in connection):
 | 
					 | 
				
			||||||
        raise InvalidUpgrade("Connection", ", ".join(connection))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    upgrade: list[UpgradeProtocol] = sum(
 | 
					 | 
				
			||||||
        [parse_upgrade(value) for value in headers.get_all("Upgrade")], []
 | 
					 | 
				
			||||||
    )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    # For compatibility with non-strict implementations, ignore case when
 | 
					 | 
				
			||||||
    # checking the Upgrade header. The RFC always uses "websocket", except
 | 
					 | 
				
			||||||
    # in section 11.2. (IANA registration) where it uses "WebSocket".
 | 
					 | 
				
			||||||
    if not (len(upgrade) == 1 and upgrade[0].lower() == "websocket"):
 | 
					 | 
				
			||||||
        raise InvalidUpgrade("Upgrade", ", ".join(upgrade))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    try:
 | 
					 | 
				
			||||||
        s_w_key = headers["Sec-WebSocket-Key"]
 | 
					 | 
				
			||||||
    except KeyError as exc:
 | 
					 | 
				
			||||||
        raise InvalidHeader("Sec-WebSocket-Key") from exc
 | 
					 | 
				
			||||||
    except MultipleValuesError as exc:
 | 
					 | 
				
			||||||
        raise InvalidHeader("Sec-WebSocket-Key", "multiple values") from exc
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    try:
 | 
					 | 
				
			||||||
        raw_key = base64.b64decode(s_w_key.encode(), validate=True)
 | 
					 | 
				
			||||||
    except binascii.Error as exc:
 | 
					 | 
				
			||||||
        raise InvalidHeaderValue("Sec-WebSocket-Key", s_w_key) from exc
 | 
					 | 
				
			||||||
    if len(raw_key) != 16:
 | 
					 | 
				
			||||||
        raise InvalidHeaderValue("Sec-WebSocket-Key", s_w_key)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    try:
 | 
					 | 
				
			||||||
        s_w_version = headers["Sec-WebSocket-Version"]
 | 
					 | 
				
			||||||
    except KeyError as exc:
 | 
					 | 
				
			||||||
        raise InvalidHeader("Sec-WebSocket-Version") from exc
 | 
					 | 
				
			||||||
    except MultipleValuesError as exc:
 | 
					 | 
				
			||||||
        raise InvalidHeader("Sec-WebSocket-Version", "multiple values") from exc
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    if s_w_version != "13":
 | 
					 | 
				
			||||||
        raise InvalidHeaderValue("Sec-WebSocket-Version", s_w_version)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    return s_w_key
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def build_response(headers: Headers, key: str) -> None:
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Build a handshake response to send to the client.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Update response headers passed in argument.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Args:
 | 
					 | 
				
			||||||
        headers: Handshake response headers.
 | 
					 | 
				
			||||||
        key: Returned by :func:`check_request`.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    headers["Upgrade"] = "websocket"
 | 
					 | 
				
			||||||
    headers["Connection"] = "Upgrade"
 | 
					 | 
				
			||||||
    headers["Sec-WebSocket-Accept"] = accept(key)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def check_response(headers: Headers, key: str) -> None:
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Check a handshake response received from the server.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    This function doesn't verify that the response is an HTTP/1.1 or higher
 | 
					 | 
				
			||||||
    response with a 101 status code. These controls are the responsibility of
 | 
					 | 
				
			||||||
    the caller.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Args:
 | 
					 | 
				
			||||||
        headers: Handshake response headers.
 | 
					 | 
				
			||||||
        key: Returned by :func:`build_request`.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Raises:
 | 
					 | 
				
			||||||
        InvalidHandshake: If the handshake response is invalid.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    connection: list[ConnectionOption] = sum(
 | 
					 | 
				
			||||||
        [parse_connection(value) for value in headers.get_all("Connection")], []
 | 
					 | 
				
			||||||
    )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    if not any(value.lower() == "upgrade" for value in connection):
 | 
					 | 
				
			||||||
        raise InvalidUpgrade("Connection", " ".join(connection))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    upgrade: list[UpgradeProtocol] = sum(
 | 
					 | 
				
			||||||
        [parse_upgrade(value) for value in headers.get_all("Upgrade")], []
 | 
					 | 
				
			||||||
    )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    # For compatibility with non-strict implementations, ignore case when
 | 
					 | 
				
			||||||
    # checking the Upgrade header. The RFC always uses "websocket", except
 | 
					 | 
				
			||||||
    # in section 11.2. (IANA registration) where it uses "WebSocket".
 | 
					 | 
				
			||||||
    if not (len(upgrade) == 1 and upgrade[0].lower() == "websocket"):
 | 
					 | 
				
			||||||
        raise InvalidUpgrade("Upgrade", ", ".join(upgrade))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    try:
 | 
					 | 
				
			||||||
        s_w_accept = headers["Sec-WebSocket-Accept"]
 | 
					 | 
				
			||||||
    except KeyError as exc:
 | 
					 | 
				
			||||||
        raise InvalidHeader("Sec-WebSocket-Accept") from exc
 | 
					 | 
				
			||||||
    except MultipleValuesError as exc:
 | 
					 | 
				
			||||||
        raise InvalidHeader("Sec-WebSocket-Accept", "multiple values") from exc
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    if s_w_accept != accept(key):
 | 
					 | 
				
			||||||
        raise InvalidHeaderValue("Sec-WebSocket-Accept", s_w_accept)
 | 
					 | 
				
			||||||
@@ -1,201 +0,0 @@
 | 
				
			|||||||
from __future__ import annotations
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
import asyncio
 | 
					 | 
				
			||||||
import os
 | 
					 | 
				
			||||||
import re
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
from ..datastructures import Headers
 | 
					 | 
				
			||||||
from ..exceptions import SecurityError
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
__all__ = ["read_request", "read_response"]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
MAX_NUM_HEADERS = int(os.environ.get("WEBSOCKETS_MAX_NUM_HEADERS", "128"))
 | 
					 | 
				
			||||||
MAX_LINE_LENGTH = int(os.environ.get("WEBSOCKETS_MAX_LINE_LENGTH", "8192"))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def d(value: bytes) -> str:
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Decode a bytestring for interpolating into an error message.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    return value.decode(errors="backslashreplace")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
# See https://datatracker.ietf.org/doc/html/rfc7230#appendix-B.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
# Regex for validating header names.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
_token_re = re.compile(rb"[-!#$%&\'*+.^_`|~0-9a-zA-Z]+")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
# Regex for validating header values.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
# We don't attempt to support obsolete line folding.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
# Include HTAB (\x09), SP (\x20), VCHAR (\x21-\x7e), obs-text (\x80-\xff).
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
# The ABNF is complicated because it attempts to express that optional
 | 
					 | 
				
			||||||
# whitespace is ignored. We strip whitespace and don't revalidate that.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
# See also https://www.rfc-editor.org/errata_search.php?rfc=7230&eid=4189
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
_value_re = re.compile(rb"[\x09\x20-\x7e\x80-\xff]*")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
async def read_request(stream: asyncio.StreamReader) -> tuple[str, Headers]:
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Read an HTTP/1.1 GET request and return ``(path, headers)``.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    ``path`` isn't URL-decoded or validated in any way.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    ``path`` and ``headers`` are expected to contain only ASCII characters.
 | 
					 | 
				
			||||||
    Other characters are represented with surrogate escapes.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    :func:`read_request` doesn't attempt to read the request body because
 | 
					 | 
				
			||||||
    WebSocket handshake requests don't have one. If the request contains a
 | 
					 | 
				
			||||||
    body, it may be read from ``stream`` after this coroutine returns.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Args:
 | 
					 | 
				
			||||||
        stream: Input to read the request from.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Raises:
 | 
					 | 
				
			||||||
        EOFError: If the connection is closed without a full HTTP request.
 | 
					 | 
				
			||||||
        SecurityError: If the request exceeds a security limit.
 | 
					 | 
				
			||||||
        ValueError: If the request isn't well formatted.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    # https://datatracker.ietf.org/doc/html/rfc7230#section-3.1.1
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    # Parsing is simple because fixed values are expected for method and
 | 
					 | 
				
			||||||
    # version and because path isn't checked. Since WebSocket software tends
 | 
					 | 
				
			||||||
    # to implement HTTP/1.1 strictly, there's little need for lenient parsing.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    try:
 | 
					 | 
				
			||||||
        request_line = await read_line(stream)
 | 
					 | 
				
			||||||
    except EOFError as exc:
 | 
					 | 
				
			||||||
        raise EOFError("connection closed while reading HTTP request line") from exc
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    try:
 | 
					 | 
				
			||||||
        method, raw_path, version = request_line.split(b" ", 2)
 | 
					 | 
				
			||||||
    except ValueError:  # not enough values to unpack (expected 3, got 1-2)
 | 
					 | 
				
			||||||
        raise ValueError(f"invalid HTTP request line: {d(request_line)}") from None
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    if method != b"GET":
 | 
					 | 
				
			||||||
        raise ValueError(f"unsupported HTTP method: {d(method)}")
 | 
					 | 
				
			||||||
    if version != b"HTTP/1.1":
 | 
					 | 
				
			||||||
        raise ValueError(f"unsupported HTTP version: {d(version)}")
 | 
					 | 
				
			||||||
    path = raw_path.decode("ascii", "surrogateescape")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    headers = await read_headers(stream)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    return path, headers
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
async def read_response(stream: asyncio.StreamReader) -> tuple[int, str, Headers]:
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Read an HTTP/1.1 response and return ``(status_code, reason, headers)``.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    ``reason`` and ``headers`` are expected to contain only ASCII characters.
 | 
					 | 
				
			||||||
    Other characters are represented with surrogate escapes.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    :func:`read_request` doesn't attempt to read the response body because
 | 
					 | 
				
			||||||
    WebSocket handshake responses don't have one. If the response contains a
 | 
					 | 
				
			||||||
    body, it may be read from ``stream`` after this coroutine returns.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Args:
 | 
					 | 
				
			||||||
        stream: Input to read the response from.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Raises:
 | 
					 | 
				
			||||||
        EOFError: If the connection is closed without a full HTTP response.
 | 
					 | 
				
			||||||
        SecurityError: If the response exceeds a security limit.
 | 
					 | 
				
			||||||
        ValueError: If the response isn't well formatted.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    # https://datatracker.ietf.org/doc/html/rfc7230#section-3.1.2
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    # As in read_request, parsing is simple because a fixed value is expected
 | 
					 | 
				
			||||||
    # for version, status_code is a 3-digit number, and reason can be ignored.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    try:
 | 
					 | 
				
			||||||
        status_line = await read_line(stream)
 | 
					 | 
				
			||||||
    except EOFError as exc:
 | 
					 | 
				
			||||||
        raise EOFError("connection closed while reading HTTP status line") from exc
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    try:
 | 
					 | 
				
			||||||
        version, raw_status_code, raw_reason = status_line.split(b" ", 2)
 | 
					 | 
				
			||||||
    except ValueError:  # not enough values to unpack (expected 3, got 1-2)
 | 
					 | 
				
			||||||
        raise ValueError(f"invalid HTTP status line: {d(status_line)}") from None
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    if version != b"HTTP/1.1":
 | 
					 | 
				
			||||||
        raise ValueError(f"unsupported HTTP version: {d(version)}")
 | 
					 | 
				
			||||||
    try:
 | 
					 | 
				
			||||||
        status_code = int(raw_status_code)
 | 
					 | 
				
			||||||
    except ValueError:  # invalid literal for int() with base 10
 | 
					 | 
				
			||||||
        raise ValueError(f"invalid HTTP status code: {d(raw_status_code)}") from None
 | 
					 | 
				
			||||||
    if not 100 <= status_code < 1000:
 | 
					 | 
				
			||||||
        raise ValueError(f"unsupported HTTP status code: {d(raw_status_code)}")
 | 
					 | 
				
			||||||
    if not _value_re.fullmatch(raw_reason):
 | 
					 | 
				
			||||||
        raise ValueError(f"invalid HTTP reason phrase: {d(raw_reason)}")
 | 
					 | 
				
			||||||
    reason = raw_reason.decode()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    headers = await read_headers(stream)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    return status_code, reason, headers
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
async def read_headers(stream: asyncio.StreamReader) -> Headers:
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Read HTTP headers from ``stream``.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Non-ASCII characters are represented with surrogate escapes.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    # https://datatracker.ietf.org/doc/html/rfc7230#section-3.2
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    # We don't attempt to support obsolete line folding.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    headers = Headers()
 | 
					 | 
				
			||||||
    for _ in range(MAX_NUM_HEADERS + 1):
 | 
					 | 
				
			||||||
        try:
 | 
					 | 
				
			||||||
            line = await read_line(stream)
 | 
					 | 
				
			||||||
        except EOFError as exc:
 | 
					 | 
				
			||||||
            raise EOFError("connection closed while reading HTTP headers") from exc
 | 
					 | 
				
			||||||
        if line == b"":
 | 
					 | 
				
			||||||
            break
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        try:
 | 
					 | 
				
			||||||
            raw_name, raw_value = line.split(b":", 1)
 | 
					 | 
				
			||||||
        except ValueError:  # not enough values to unpack (expected 2, got 1)
 | 
					 | 
				
			||||||
            raise ValueError(f"invalid HTTP header line: {d(line)}") from None
 | 
					 | 
				
			||||||
        if not _token_re.fullmatch(raw_name):
 | 
					 | 
				
			||||||
            raise ValueError(f"invalid HTTP header name: {d(raw_name)}")
 | 
					 | 
				
			||||||
        raw_value = raw_value.strip(b" \t")
 | 
					 | 
				
			||||||
        if not _value_re.fullmatch(raw_value):
 | 
					 | 
				
			||||||
            raise ValueError(f"invalid HTTP header value: {d(raw_value)}")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        name = raw_name.decode("ascii")  # guaranteed to be ASCII at this point
 | 
					 | 
				
			||||||
        value = raw_value.decode("ascii", "surrogateescape")
 | 
					 | 
				
			||||||
        headers[name] = value
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    else:
 | 
					 | 
				
			||||||
        raise SecurityError("too many HTTP headers")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    return headers
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
async def read_line(stream: asyncio.StreamReader) -> bytes:
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Read a single line from ``stream``.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    CRLF is stripped from the return value.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    # Security: this is bounded by the StreamReader's limit (default = 32 KiB).
 | 
					 | 
				
			||||||
    line = await stream.readline()
 | 
					 | 
				
			||||||
    # Security: this guarantees header values are small (hard-coded = 8 KiB)
 | 
					 | 
				
			||||||
    if len(line) > MAX_LINE_LENGTH:
 | 
					 | 
				
			||||||
        raise SecurityError("line too long")
 | 
					 | 
				
			||||||
    # Not mandatory but safe - https://datatracker.ietf.org/doc/html/rfc7230#section-3.5
 | 
					 | 
				
			||||||
    if not line.endswith(b"\r\n"):
 | 
					 | 
				
			||||||
        raise EOFError("line without CRLF")
 | 
					 | 
				
			||||||
    return line[:-2]
 | 
					 | 
				
			||||||
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							@@ -1,732 +0,0 @@
 | 
				
			|||||||
from __future__ import annotations
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
import enum
 | 
					 | 
				
			||||||
import logging
 | 
					 | 
				
			||||||
import uuid
 | 
					 | 
				
			||||||
from typing import Generator, Union
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
from .exceptions import (
 | 
					 | 
				
			||||||
    ConnectionClosed,
 | 
					 | 
				
			||||||
    ConnectionClosedError,
 | 
					 | 
				
			||||||
    ConnectionClosedOK,
 | 
					 | 
				
			||||||
    InvalidState,
 | 
					 | 
				
			||||||
    PayloadTooBig,
 | 
					 | 
				
			||||||
    ProtocolError,
 | 
					 | 
				
			||||||
)
 | 
					 | 
				
			||||||
from .extensions import Extension
 | 
					 | 
				
			||||||
from .frames import (
 | 
					 | 
				
			||||||
    OK_CLOSE_CODES,
 | 
					 | 
				
			||||||
    OP_BINARY,
 | 
					 | 
				
			||||||
    OP_CLOSE,
 | 
					 | 
				
			||||||
    OP_CONT,
 | 
					 | 
				
			||||||
    OP_PING,
 | 
					 | 
				
			||||||
    OP_PONG,
 | 
					 | 
				
			||||||
    OP_TEXT,
 | 
					 | 
				
			||||||
    Close,
 | 
					 | 
				
			||||||
    CloseCode,
 | 
					 | 
				
			||||||
    Frame,
 | 
					 | 
				
			||||||
)
 | 
					 | 
				
			||||||
from .http11 import Request, Response
 | 
					 | 
				
			||||||
from .streams import StreamReader
 | 
					 | 
				
			||||||
from .typing import LoggerLike, Origin, Subprotocol
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
__all__ = [
 | 
					 | 
				
			||||||
    "Protocol",
 | 
					 | 
				
			||||||
    "Side",
 | 
					 | 
				
			||||||
    "State",
 | 
					 | 
				
			||||||
    "SEND_EOF",
 | 
					 | 
				
			||||||
]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
# Change to Request | Response | Frame when dropping Python < 3.10.
 | 
					 | 
				
			||||||
Event = Union[Request, Response, Frame]
 | 
					 | 
				
			||||||
"""Events that :meth:`~Protocol.events_received` may return."""
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class Side(enum.IntEnum):
 | 
					 | 
				
			||||||
    """A WebSocket connection is either a server or a client."""
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    SERVER, CLIENT = range(2)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
SERVER = Side.SERVER
 | 
					 | 
				
			||||||
CLIENT = Side.CLIENT
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class State(enum.IntEnum):
 | 
					 | 
				
			||||||
    """A WebSocket connection is in one of these four states."""
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    CONNECTING, OPEN, CLOSING, CLOSED = range(4)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
CONNECTING = State.CONNECTING
 | 
					 | 
				
			||||||
OPEN = State.OPEN
 | 
					 | 
				
			||||||
CLOSING = State.CLOSING
 | 
					 | 
				
			||||||
CLOSED = State.CLOSED
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
SEND_EOF = b""
 | 
					 | 
				
			||||||
"""Sentinel signaling that the TCP connection must be half-closed."""
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class Protocol:
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Sans-I/O implementation of a WebSocket connection.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Args:
 | 
					 | 
				
			||||||
        side: :attr:`~Side.CLIENT` or :attr:`~Side.SERVER`.
 | 
					 | 
				
			||||||
        state: Initial state of the WebSocket connection.
 | 
					 | 
				
			||||||
        max_size: Maximum size of incoming messages in bytes;
 | 
					 | 
				
			||||||
            :obj:`None` disables the limit.
 | 
					 | 
				
			||||||
        logger: Logger for this connection; depending on ``side``,
 | 
					 | 
				
			||||||
            defaults to ``logging.getLogger("websockets.client")``
 | 
					 | 
				
			||||||
            or ``logging.getLogger("websockets.server")``;
 | 
					 | 
				
			||||||
            see the :doc:`logging guide <../../topics/logging>` for details.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __init__(
 | 
					 | 
				
			||||||
        self,
 | 
					 | 
				
			||||||
        side: Side,
 | 
					 | 
				
			||||||
        *,
 | 
					 | 
				
			||||||
        state: State = OPEN,
 | 
					 | 
				
			||||||
        max_size: int | None = 2**20,
 | 
					 | 
				
			||||||
        logger: LoggerLike | None = None,
 | 
					 | 
				
			||||||
    ) -> None:
 | 
					 | 
				
			||||||
        # Unique identifier. For logs.
 | 
					 | 
				
			||||||
        self.id: uuid.UUID = uuid.uuid4()
 | 
					 | 
				
			||||||
        """Unique identifier of the connection. Useful in logs."""
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # Logger or LoggerAdapter for this connection.
 | 
					 | 
				
			||||||
        if logger is None:
 | 
					 | 
				
			||||||
            logger = logging.getLogger(f"websockets.{side.name.lower()}")
 | 
					 | 
				
			||||||
        self.logger: LoggerLike = logger
 | 
					 | 
				
			||||||
        """Logger for this connection."""
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # Track if DEBUG is enabled. Shortcut logging calls if it isn't.
 | 
					 | 
				
			||||||
        self.debug = logger.isEnabledFor(logging.DEBUG)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # Connection side. CLIENT or SERVER.
 | 
					 | 
				
			||||||
        self.side = side
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # Connection state. Initially OPEN because subclasses handle CONNECTING.
 | 
					 | 
				
			||||||
        self.state = state
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # Maximum size of incoming messages in bytes.
 | 
					 | 
				
			||||||
        self.max_size = max_size
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # Current size of incoming message in bytes. Only set while reading a
 | 
					 | 
				
			||||||
        # fragmented message i.e. a data frames with the FIN bit not set.
 | 
					 | 
				
			||||||
        self.cur_size: int | None = None
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # True while sending a fragmented message i.e. a data frames with the
 | 
					 | 
				
			||||||
        # FIN bit not set.
 | 
					 | 
				
			||||||
        self.expect_continuation_frame = False
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # WebSocket protocol parameters.
 | 
					 | 
				
			||||||
        self.origin: Origin | None = None
 | 
					 | 
				
			||||||
        self.extensions: list[Extension] = []
 | 
					 | 
				
			||||||
        self.subprotocol: Subprotocol | None = None
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # Close code and reason, set when a close frame is sent or received.
 | 
					 | 
				
			||||||
        self.close_rcvd: Close | None = None
 | 
					 | 
				
			||||||
        self.close_sent: Close | None = None
 | 
					 | 
				
			||||||
        self.close_rcvd_then_sent: bool | None = None
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # Track if an exception happened during the handshake.
 | 
					 | 
				
			||||||
        self.handshake_exc: Exception | None = None
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Exception to raise if the opening handshake failed.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        :obj:`None` if the opening handshake succeeded.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # Track if send_eof() was called.
 | 
					 | 
				
			||||||
        self.eof_sent = False
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # Parser state.
 | 
					 | 
				
			||||||
        self.reader = StreamReader()
 | 
					 | 
				
			||||||
        self.events: list[Event] = []
 | 
					 | 
				
			||||||
        self.writes: list[bytes] = []
 | 
					 | 
				
			||||||
        self.parser = self.parse()
 | 
					 | 
				
			||||||
        next(self.parser)  # start coroutine
 | 
					 | 
				
			||||||
        self.parser_exc: Exception | None = None
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    @property
 | 
					 | 
				
			||||||
    def state(self) -> State:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        State of the WebSocket connection.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Defined in 4.1, 4.2, 7.1.3, and 7.1.4 of :rfc:`6455`.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        return self._state
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    @state.setter
 | 
					 | 
				
			||||||
    def state(self, state: State) -> None:
 | 
					 | 
				
			||||||
        if self.debug:
 | 
					 | 
				
			||||||
            self.logger.debug("= connection is %s", state.name)
 | 
					 | 
				
			||||||
        self._state = state
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    @property
 | 
					 | 
				
			||||||
    def close_code(self) -> int | None:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        `WebSocket close code`_.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        .. _WebSocket close code:
 | 
					 | 
				
			||||||
            https://datatracker.ietf.org/doc/html/rfc6455#section-7.1.5
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        :obj:`None` if the connection isn't closed yet.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        if self.state is not CLOSED:
 | 
					 | 
				
			||||||
            return None
 | 
					 | 
				
			||||||
        elif self.close_rcvd is None:
 | 
					 | 
				
			||||||
            return CloseCode.ABNORMAL_CLOSURE
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            return self.close_rcvd.code
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    @property
 | 
					 | 
				
			||||||
    def close_reason(self) -> str | None:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        `WebSocket close reason`_.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        .. _WebSocket close reason:
 | 
					 | 
				
			||||||
            https://datatracker.ietf.org/doc/html/rfc6455#section-7.1.6
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        :obj:`None` if the connection isn't closed yet.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        if self.state is not CLOSED:
 | 
					 | 
				
			||||||
            return None
 | 
					 | 
				
			||||||
        elif self.close_rcvd is None:
 | 
					 | 
				
			||||||
            return ""
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            return self.close_rcvd.reason
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    @property
 | 
					 | 
				
			||||||
    def close_exc(self) -> ConnectionClosed:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Exception to raise when trying to interact with a closed connection.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Don't raise this exception while the connection :attr:`state`
 | 
					 | 
				
			||||||
        is :attr:`~websockets.protocol.State.CLOSING`; wait until
 | 
					 | 
				
			||||||
        it's :attr:`~websockets.protocol.State.CLOSED`.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Indeed, the exception includes the close code and reason, which are
 | 
					 | 
				
			||||||
        known only once the connection is closed.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Raises:
 | 
					 | 
				
			||||||
            AssertionError: If the connection isn't closed yet.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        assert self.state is CLOSED, "connection isn't closed yet"
 | 
					 | 
				
			||||||
        exc_type: type[ConnectionClosed]
 | 
					 | 
				
			||||||
        if (
 | 
					 | 
				
			||||||
            self.close_rcvd is not None
 | 
					 | 
				
			||||||
            and self.close_sent is not None
 | 
					 | 
				
			||||||
            and self.close_rcvd.code in OK_CLOSE_CODES
 | 
					 | 
				
			||||||
            and self.close_sent.code in OK_CLOSE_CODES
 | 
					 | 
				
			||||||
        ):
 | 
					 | 
				
			||||||
            exc_type = ConnectionClosedOK
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            exc_type = ConnectionClosedError
 | 
					 | 
				
			||||||
        exc: ConnectionClosed = exc_type(
 | 
					 | 
				
			||||||
            self.close_rcvd,
 | 
					 | 
				
			||||||
            self.close_sent,
 | 
					 | 
				
			||||||
            self.close_rcvd_then_sent,
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        # Chain to the exception raised in the parser, if any.
 | 
					 | 
				
			||||||
        exc.__cause__ = self.parser_exc
 | 
					 | 
				
			||||||
        return exc
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    # Public methods for receiving data.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def receive_data(self, data: bytes) -> None:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Receive data from the network.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        After calling this method:
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        - You must call :meth:`data_to_send` and send this data to the network.
 | 
					 | 
				
			||||||
        - You should call :meth:`events_received` and process resulting events.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Raises:
 | 
					 | 
				
			||||||
            EOFError: If :meth:`receive_eof` was called earlier.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        self.reader.feed_data(data)
 | 
					 | 
				
			||||||
        next(self.parser)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def receive_eof(self) -> None:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Receive the end of the data stream from the network.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        After calling this method:
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        - You must call :meth:`data_to_send` and send this data to the network;
 | 
					 | 
				
			||||||
          it will return ``[b""]``, signaling the end of the stream, or ``[]``.
 | 
					 | 
				
			||||||
        - You aren't expected to call :meth:`events_received`; it won't return
 | 
					 | 
				
			||||||
          any new events.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        :meth:`receive_eof` is idempotent.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        if self.reader.eof:
 | 
					 | 
				
			||||||
            return
 | 
					 | 
				
			||||||
        self.reader.feed_eof()
 | 
					 | 
				
			||||||
        next(self.parser)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    # Public methods for sending events.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def send_continuation(self, data: bytes, fin: bool) -> None:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Send a `Continuation frame`_.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        .. _Continuation frame:
 | 
					 | 
				
			||||||
            https://datatracker.ietf.org/doc/html/rfc6455#section-5.6
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Parameters:
 | 
					 | 
				
			||||||
            data: payload containing the same kind of data
 | 
					 | 
				
			||||||
                as the initial frame.
 | 
					 | 
				
			||||||
            fin: FIN bit; set it to :obj:`True` if this is the last frame
 | 
					 | 
				
			||||||
                of a fragmented message and to :obj:`False` otherwise.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Raises:
 | 
					 | 
				
			||||||
            ProtocolError: If a fragmented message isn't in progress.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        if not self.expect_continuation_frame:
 | 
					 | 
				
			||||||
            raise ProtocolError("unexpected continuation frame")
 | 
					 | 
				
			||||||
        if self._state is not OPEN:
 | 
					 | 
				
			||||||
            raise InvalidState(f"connection is {self.state.name.lower()}")
 | 
					 | 
				
			||||||
        self.expect_continuation_frame = not fin
 | 
					 | 
				
			||||||
        self.send_frame(Frame(OP_CONT, data, fin))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def send_text(self, data: bytes, fin: bool = True) -> None:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Send a `Text frame`_.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        .. _Text frame:
 | 
					 | 
				
			||||||
            https://datatracker.ietf.org/doc/html/rfc6455#section-5.6
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Parameters:
 | 
					 | 
				
			||||||
            data: payload containing text encoded with UTF-8.
 | 
					 | 
				
			||||||
            fin: FIN bit; set it to :obj:`False` if this is the first frame of
 | 
					 | 
				
			||||||
                a fragmented message.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Raises:
 | 
					 | 
				
			||||||
            ProtocolError: If a fragmented message is in progress.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        if self.expect_continuation_frame:
 | 
					 | 
				
			||||||
            raise ProtocolError("expected a continuation frame")
 | 
					 | 
				
			||||||
        if self._state is not OPEN:
 | 
					 | 
				
			||||||
            raise InvalidState(f"connection is {self.state.name.lower()}")
 | 
					 | 
				
			||||||
        self.expect_continuation_frame = not fin
 | 
					 | 
				
			||||||
        self.send_frame(Frame(OP_TEXT, data, fin))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def send_binary(self, data: bytes, fin: bool = True) -> None:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Send a `Binary frame`_.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        .. _Binary frame:
 | 
					 | 
				
			||||||
            https://datatracker.ietf.org/doc/html/rfc6455#section-5.6
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Parameters:
 | 
					 | 
				
			||||||
            data: payload containing arbitrary binary data.
 | 
					 | 
				
			||||||
            fin: FIN bit; set it to :obj:`False` if this is the first frame of
 | 
					 | 
				
			||||||
                a fragmented message.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Raises:
 | 
					 | 
				
			||||||
            ProtocolError: If a fragmented message is in progress.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        if self.expect_continuation_frame:
 | 
					 | 
				
			||||||
            raise ProtocolError("expected a continuation frame")
 | 
					 | 
				
			||||||
        if self._state is not OPEN:
 | 
					 | 
				
			||||||
            raise InvalidState(f"connection is {self.state.name.lower()}")
 | 
					 | 
				
			||||||
        self.expect_continuation_frame = not fin
 | 
					 | 
				
			||||||
        self.send_frame(Frame(OP_BINARY, data, fin))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def send_close(self, code: int | None = None, reason: str = "") -> None:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Send a `Close frame`_.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        .. _Close frame:
 | 
					 | 
				
			||||||
            https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.1
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Parameters:
 | 
					 | 
				
			||||||
            code: close code.
 | 
					 | 
				
			||||||
            reason: close reason.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Raises:
 | 
					 | 
				
			||||||
            ProtocolError: If the code isn't valid or if a reason is provided
 | 
					 | 
				
			||||||
                without a code.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        # While RFC 6455 doesn't rule out sending more than one close Frame,
 | 
					 | 
				
			||||||
        # websockets is conservative in what it sends and doesn't allow that.
 | 
					 | 
				
			||||||
        if self._state is not OPEN:
 | 
					 | 
				
			||||||
            raise InvalidState(f"connection is {self.state.name.lower()}")
 | 
					 | 
				
			||||||
        if code is None:
 | 
					 | 
				
			||||||
            if reason != "":
 | 
					 | 
				
			||||||
                raise ProtocolError("cannot send a reason without a code")
 | 
					 | 
				
			||||||
            close = Close(CloseCode.NO_STATUS_RCVD, "")
 | 
					 | 
				
			||||||
            data = b""
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            close = Close(code, reason)
 | 
					 | 
				
			||||||
            data = close.serialize()
 | 
					 | 
				
			||||||
        # 7.1.3. The WebSocket Closing Handshake is Started
 | 
					 | 
				
			||||||
        self.send_frame(Frame(OP_CLOSE, data))
 | 
					 | 
				
			||||||
        # Since the state is OPEN, no close frame was received yet.
 | 
					 | 
				
			||||||
        # As a consequence, self.close_rcvd_then_sent remains None.
 | 
					 | 
				
			||||||
        assert self.close_rcvd is None
 | 
					 | 
				
			||||||
        self.close_sent = close
 | 
					 | 
				
			||||||
        self.state = CLOSING
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def send_ping(self, data: bytes) -> None:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Send a `Ping frame`_.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        .. _Ping frame:
 | 
					 | 
				
			||||||
            https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.2
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Parameters:
 | 
					 | 
				
			||||||
            data: payload containing arbitrary binary data.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        # RFC 6455 allows control frames after starting the closing handshake.
 | 
					 | 
				
			||||||
        if self._state is not OPEN and self._state is not CLOSING:
 | 
					 | 
				
			||||||
            raise InvalidState(f"connection is {self.state.name.lower()}")
 | 
					 | 
				
			||||||
        self.send_frame(Frame(OP_PING, data))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def send_pong(self, data: bytes) -> None:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Send a `Pong frame`_.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        .. _Pong frame:
 | 
					 | 
				
			||||||
            https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.3
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Parameters:
 | 
					 | 
				
			||||||
            data: payload containing arbitrary binary data.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        # RFC 6455 allows control frames after starting the closing handshake.
 | 
					 | 
				
			||||||
        if self._state is not OPEN and self._state is not CLOSING:
 | 
					 | 
				
			||||||
            raise InvalidState(f"connection is {self.state.name.lower()}")
 | 
					 | 
				
			||||||
        self.send_frame(Frame(OP_PONG, data))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def fail(self, code: int, reason: str = "") -> None:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        `Fail the WebSocket connection`_.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        .. _Fail the WebSocket connection:
 | 
					 | 
				
			||||||
            https://datatracker.ietf.org/doc/html/rfc6455#section-7.1.7
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Parameters:
 | 
					 | 
				
			||||||
            code: close code
 | 
					 | 
				
			||||||
            reason: close reason
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Raises:
 | 
					 | 
				
			||||||
            ProtocolError: If the code isn't valid.
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        # 7.1.7. Fail the WebSocket Connection
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # Send a close frame when the state is OPEN (a close frame was already
 | 
					 | 
				
			||||||
        # sent if it's CLOSING), except when failing the connection because
 | 
					 | 
				
			||||||
        # of an error reading from or writing to the network.
 | 
					 | 
				
			||||||
        if self.state is OPEN:
 | 
					 | 
				
			||||||
            if code != CloseCode.ABNORMAL_CLOSURE:
 | 
					 | 
				
			||||||
                close = Close(code, reason)
 | 
					 | 
				
			||||||
                data = close.serialize()
 | 
					 | 
				
			||||||
                self.send_frame(Frame(OP_CLOSE, data))
 | 
					 | 
				
			||||||
                self.close_sent = close
 | 
					 | 
				
			||||||
                # If recv_messages() raised an exception upon receiving a close
 | 
					 | 
				
			||||||
                # frame but before echoing it, then close_rcvd is not None even
 | 
					 | 
				
			||||||
                # though the state is OPEN. This happens when the connection is
 | 
					 | 
				
			||||||
                # closed while receiving a fragmented message.
 | 
					 | 
				
			||||||
                if self.close_rcvd is not None:
 | 
					 | 
				
			||||||
                    self.close_rcvd_then_sent = True
 | 
					 | 
				
			||||||
                self.state = CLOSING
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # When failing the connection, a server closes the TCP connection
 | 
					 | 
				
			||||||
        # without waiting for the client to complete the handshake, while a
 | 
					 | 
				
			||||||
        # client waits for the server to close the TCP connection, possibly
 | 
					 | 
				
			||||||
        # after sending a close frame that the client will ignore.
 | 
					 | 
				
			||||||
        if self.side is SERVER and not self.eof_sent:
 | 
					 | 
				
			||||||
            self.send_eof()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # 7.1.7. Fail the WebSocket Connection "An endpoint MUST NOT continue
 | 
					 | 
				
			||||||
        # to attempt to process data(including a responding Close frame) from
 | 
					 | 
				
			||||||
        # the remote endpoint after being instructed to _Fail the WebSocket
 | 
					 | 
				
			||||||
        # Connection_."
 | 
					 | 
				
			||||||
        self.parser = self.discard()
 | 
					 | 
				
			||||||
        next(self.parser)  # start coroutine
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    # Public method for getting incoming events after receiving data.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def events_received(self) -> list[Event]:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Fetch events generated from data received from the network.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Call this method immediately after any of the ``receive_*()`` methods.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Process resulting events, likely by passing them to the application.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Returns:
 | 
					 | 
				
			||||||
            Events read from the connection.
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        events, self.events = self.events, []
 | 
					 | 
				
			||||||
        return events
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    # Public method for getting outgoing data after receiving data or sending events.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def data_to_send(self) -> list[bytes]:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Obtain data to send to the network.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Call this method immediately after any of the ``receive_*()``,
 | 
					 | 
				
			||||||
        ``send_*()``, or :meth:`fail` methods.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Write resulting data to the connection.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        The empty bytestring :data:`~websockets.protocol.SEND_EOF` signals
 | 
					 | 
				
			||||||
        the end of the data stream. When you receive it, half-close the TCP
 | 
					 | 
				
			||||||
        connection.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Returns:
 | 
					 | 
				
			||||||
            Data to write to the connection.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        writes, self.writes = self.writes, []
 | 
					 | 
				
			||||||
        return writes
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def close_expected(self) -> bool:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Tell if the TCP connection is expected to close soon.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Call this method immediately after any of the ``receive_*()``,
 | 
					 | 
				
			||||||
        ``send_close()``, or :meth:`fail` methods.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        If it returns :obj:`True`, schedule closing the TCP connection after a
 | 
					 | 
				
			||||||
        short timeout if the other side hasn't already closed it.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Returns:
 | 
					 | 
				
			||||||
            Whether the TCP connection is expected to close soon.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        # We expect a TCP close if and only if we sent a close frame:
 | 
					 | 
				
			||||||
        # * Normal closure: once we send a close frame, we expect a TCP close:
 | 
					 | 
				
			||||||
        #   server waits for client to complete the TCP closing handshake;
 | 
					 | 
				
			||||||
        #   client waits for server to initiate the TCP closing handshake.
 | 
					 | 
				
			||||||
        # * Abnormal closure: we always send a close frame and the same logic
 | 
					 | 
				
			||||||
        #   applies, except on EOFError where we don't send a close frame
 | 
					 | 
				
			||||||
        #   because we already received the TCP close, so we don't expect it.
 | 
					 | 
				
			||||||
        # We already got a TCP Close if and only if the state is CLOSED.
 | 
					 | 
				
			||||||
        return self.state is CLOSING or self.handshake_exc is not None
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    # Private methods for receiving data.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def parse(self) -> Generator[None, None, None]:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Parse incoming data into frames.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        :meth:`receive_data` and :meth:`receive_eof` run this generator
 | 
					 | 
				
			||||||
        coroutine until it needs more data or reaches EOF.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        :meth:`parse` never raises an exception. Instead, it sets the
 | 
					 | 
				
			||||||
        :attr:`parser_exc` and yields control.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        try:
 | 
					 | 
				
			||||||
            while True:
 | 
					 | 
				
			||||||
                if (yield from self.reader.at_eof()):
 | 
					 | 
				
			||||||
                    if self.debug:
 | 
					 | 
				
			||||||
                        self.logger.debug("< EOF")
 | 
					 | 
				
			||||||
                    # If the WebSocket connection is closed cleanly, with a
 | 
					 | 
				
			||||||
                    # closing handhshake, recv_frame() substitutes parse()
 | 
					 | 
				
			||||||
                    # with discard(). This branch is reached only when the
 | 
					 | 
				
			||||||
                    # connection isn't closed cleanly.
 | 
					 | 
				
			||||||
                    raise EOFError("unexpected end of stream")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                if self.max_size is None:
 | 
					 | 
				
			||||||
                    max_size = None
 | 
					 | 
				
			||||||
                elif self.cur_size is None:
 | 
					 | 
				
			||||||
                    max_size = self.max_size
 | 
					 | 
				
			||||||
                else:
 | 
					 | 
				
			||||||
                    max_size = self.max_size - self.cur_size
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                # During a normal closure, execution ends here on the next
 | 
					 | 
				
			||||||
                # iteration of the loop after receiving a close frame. At
 | 
					 | 
				
			||||||
                # this point, recv_frame() replaced parse() by discard().
 | 
					 | 
				
			||||||
                frame = yield from Frame.parse(
 | 
					 | 
				
			||||||
                    self.reader.read_exact,
 | 
					 | 
				
			||||||
                    mask=self.side is SERVER,
 | 
					 | 
				
			||||||
                    max_size=max_size,
 | 
					 | 
				
			||||||
                    extensions=self.extensions,
 | 
					 | 
				
			||||||
                )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                if self.debug:
 | 
					 | 
				
			||||||
                    self.logger.debug("< %s", frame)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                self.recv_frame(frame)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        except ProtocolError as exc:
 | 
					 | 
				
			||||||
            self.fail(CloseCode.PROTOCOL_ERROR, str(exc))
 | 
					 | 
				
			||||||
            self.parser_exc = exc
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        except EOFError as exc:
 | 
					 | 
				
			||||||
            self.fail(CloseCode.ABNORMAL_CLOSURE, str(exc))
 | 
					 | 
				
			||||||
            self.parser_exc = exc
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        except UnicodeDecodeError as exc:
 | 
					 | 
				
			||||||
            self.fail(CloseCode.INVALID_DATA, f"{exc.reason} at position {exc.start}")
 | 
					 | 
				
			||||||
            self.parser_exc = exc
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        except PayloadTooBig as exc:
 | 
					 | 
				
			||||||
            self.fail(CloseCode.MESSAGE_TOO_BIG, str(exc))
 | 
					 | 
				
			||||||
            self.parser_exc = exc
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        except Exception as exc:
 | 
					 | 
				
			||||||
            self.logger.error("parser failed", exc_info=True)
 | 
					 | 
				
			||||||
            # Don't include exception details, which may be security-sensitive.
 | 
					 | 
				
			||||||
            self.fail(CloseCode.INTERNAL_ERROR)
 | 
					 | 
				
			||||||
            self.parser_exc = exc
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # During an abnormal closure, execution ends here after catching an
 | 
					 | 
				
			||||||
        # exception. At this point, fail() replaced parse() by discard().
 | 
					 | 
				
			||||||
        yield
 | 
					 | 
				
			||||||
        raise AssertionError("parse() shouldn't step after error")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def discard(self) -> Generator[None, None, None]:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Discard incoming data.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        This coroutine replaces :meth:`parse`:
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        - after receiving a close frame, during a normal closure (1.4);
 | 
					 | 
				
			||||||
        - after sending a close frame, during an abnormal closure (7.1.7).
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        # After the opening handshake completes, the server closes the TCP
 | 
					 | 
				
			||||||
        # connection in the same circumstances where discard() replaces parse().
 | 
					 | 
				
			||||||
        # The client closes it when it receives EOF from the server or times
 | 
					 | 
				
			||||||
        # out. (The latter case cannot be handled in this Sans-I/O layer.)
 | 
					 | 
				
			||||||
        assert (self.state == CONNECTING or self.side is SERVER) == (self.eof_sent)
 | 
					 | 
				
			||||||
        while not (yield from self.reader.at_eof()):
 | 
					 | 
				
			||||||
            self.reader.discard()
 | 
					 | 
				
			||||||
        if self.debug:
 | 
					 | 
				
			||||||
            self.logger.debug("< EOF")
 | 
					 | 
				
			||||||
        # A server closes the TCP connection immediately, while a client
 | 
					 | 
				
			||||||
        # waits for the server to close the TCP connection.
 | 
					 | 
				
			||||||
        if self.state != CONNECTING and self.side is CLIENT:
 | 
					 | 
				
			||||||
            self.send_eof()
 | 
					 | 
				
			||||||
        self.state = CLOSED
 | 
					 | 
				
			||||||
        # If discard() completes normally, execution ends here.
 | 
					 | 
				
			||||||
        yield
 | 
					 | 
				
			||||||
        # Once the reader reaches EOF, its feed_data/eof() methods raise an
 | 
					 | 
				
			||||||
        # error, so our receive_data/eof() methods don't step the generator.
 | 
					 | 
				
			||||||
        raise AssertionError("discard() shouldn't step after EOF")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def recv_frame(self, frame: Frame) -> None:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Process an incoming frame.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        if frame.opcode is OP_TEXT or frame.opcode is OP_BINARY:
 | 
					 | 
				
			||||||
            if self.cur_size is not None:
 | 
					 | 
				
			||||||
                raise ProtocolError("expected a continuation frame")
 | 
					 | 
				
			||||||
            if frame.fin:
 | 
					 | 
				
			||||||
                self.cur_size = None
 | 
					 | 
				
			||||||
            else:
 | 
					 | 
				
			||||||
                self.cur_size = len(frame.data)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        elif frame.opcode is OP_CONT:
 | 
					 | 
				
			||||||
            if self.cur_size is None:
 | 
					 | 
				
			||||||
                raise ProtocolError("unexpected continuation frame")
 | 
					 | 
				
			||||||
            if frame.fin:
 | 
					 | 
				
			||||||
                self.cur_size = None
 | 
					 | 
				
			||||||
            else:
 | 
					 | 
				
			||||||
                self.cur_size += len(frame.data)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        elif frame.opcode is OP_PING:
 | 
					 | 
				
			||||||
            # 5.5.2. Ping: "Upon receipt of a Ping frame, an endpoint MUST
 | 
					 | 
				
			||||||
            # send a Pong frame in response"
 | 
					 | 
				
			||||||
            pong_frame = Frame(OP_PONG, frame.data)
 | 
					 | 
				
			||||||
            self.send_frame(pong_frame)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        elif frame.opcode is OP_PONG:
 | 
					 | 
				
			||||||
            # 5.5.3 Pong: "A response to an unsolicited Pong frame is not
 | 
					 | 
				
			||||||
            # expected."
 | 
					 | 
				
			||||||
            pass
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        elif frame.opcode is OP_CLOSE:
 | 
					 | 
				
			||||||
            # 7.1.5.  The WebSocket Connection Close Code
 | 
					 | 
				
			||||||
            # 7.1.6.  The WebSocket Connection Close Reason
 | 
					 | 
				
			||||||
            self.close_rcvd = Close.parse(frame.data)
 | 
					 | 
				
			||||||
            if self.state is CLOSING:
 | 
					 | 
				
			||||||
                assert self.close_sent is not None
 | 
					 | 
				
			||||||
                self.close_rcvd_then_sent = False
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            if self.cur_size is not None:
 | 
					 | 
				
			||||||
                raise ProtocolError("incomplete fragmented message")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            # 5.5.1 Close: "If an endpoint receives a Close frame and did
 | 
					 | 
				
			||||||
            # not previously send a Close frame, the endpoint MUST send a
 | 
					 | 
				
			||||||
            # Close frame in response. (When sending a Close frame in
 | 
					 | 
				
			||||||
            # response, the endpoint typically echos the status code it
 | 
					 | 
				
			||||||
            # received.)"
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            if self.state is OPEN:
 | 
					 | 
				
			||||||
                # Echo the original data instead of re-serializing it with
 | 
					 | 
				
			||||||
                # Close.serialize() because that fails when the close frame
 | 
					 | 
				
			||||||
                # is empty and Close.parse() synthesizes a 1005 close code.
 | 
					 | 
				
			||||||
                # The rest is identical to send_close().
 | 
					 | 
				
			||||||
                self.send_frame(Frame(OP_CLOSE, frame.data))
 | 
					 | 
				
			||||||
                self.close_sent = self.close_rcvd
 | 
					 | 
				
			||||||
                self.close_rcvd_then_sent = True
 | 
					 | 
				
			||||||
                self.state = CLOSING
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            # 7.1.2. Start the WebSocket Closing Handshake: "Once an
 | 
					 | 
				
			||||||
            # endpoint has both sent and received a Close control frame,
 | 
					 | 
				
			||||||
            # that endpoint SHOULD _Close the WebSocket Connection_"
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            # A server closes the TCP connection immediately, while a client
 | 
					 | 
				
			||||||
            # waits for the server to close the TCP connection.
 | 
					 | 
				
			||||||
            if self.side is SERVER:
 | 
					 | 
				
			||||||
                self.send_eof()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            # 1.4. Closing Handshake: "after receiving a control frame
 | 
					 | 
				
			||||||
            # indicating the connection should be closed, a peer discards
 | 
					 | 
				
			||||||
            # any further data received."
 | 
					 | 
				
			||||||
            # RFC 6455 allows reading Ping and Pong frames after a Close frame.
 | 
					 | 
				
			||||||
            # However, that doesn't seem useful; websockets doesn't support it.
 | 
					 | 
				
			||||||
            self.parser = self.discard()
 | 
					 | 
				
			||||||
            next(self.parser)  # start coroutine
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            # This can't happen because Frame.parse() validates opcodes.
 | 
					 | 
				
			||||||
            raise AssertionError(f"unexpected opcode: {frame.opcode:02x}")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        self.events.append(frame)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    # Private methods for sending events.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def send_frame(self, frame: Frame) -> None:
 | 
					 | 
				
			||||||
        if self.debug:
 | 
					 | 
				
			||||||
            self.logger.debug("> %s", frame)
 | 
					 | 
				
			||||||
        self.writes.append(
 | 
					 | 
				
			||||||
            frame.serialize(
 | 
					 | 
				
			||||||
                mask=self.side is CLIENT,
 | 
					 | 
				
			||||||
                extensions=self.extensions,
 | 
					 | 
				
			||||||
            )
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def send_eof(self) -> None:
 | 
					 | 
				
			||||||
        assert not self.eof_sent
 | 
					 | 
				
			||||||
        self.eof_sent = True
 | 
					 | 
				
			||||||
        if self.debug:
 | 
					 | 
				
			||||||
            self.logger.debug("> EOF")
 | 
					 | 
				
			||||||
        self.writes.append(SEND_EOF)
 | 
					 | 
				
			||||||
@@ -1,587 +0,0 @@
 | 
				
			|||||||
from __future__ import annotations
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
import base64
 | 
					 | 
				
			||||||
import binascii
 | 
					 | 
				
			||||||
import email.utils
 | 
					 | 
				
			||||||
import http
 | 
					 | 
				
			||||||
import warnings
 | 
					 | 
				
			||||||
from typing import Any, Callable, Generator, Sequence, cast
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
from .datastructures import Headers, MultipleValuesError
 | 
					 | 
				
			||||||
from .exceptions import (
 | 
					 | 
				
			||||||
    InvalidHandshake,
 | 
					 | 
				
			||||||
    InvalidHeader,
 | 
					 | 
				
			||||||
    InvalidHeaderValue,
 | 
					 | 
				
			||||||
    InvalidOrigin,
 | 
					 | 
				
			||||||
    InvalidStatus,
 | 
					 | 
				
			||||||
    InvalidUpgrade,
 | 
					 | 
				
			||||||
    NegotiationError,
 | 
					 | 
				
			||||||
)
 | 
					 | 
				
			||||||
from .extensions import Extension, ServerExtensionFactory
 | 
					 | 
				
			||||||
from .headers import (
 | 
					 | 
				
			||||||
    build_extension,
 | 
					 | 
				
			||||||
    parse_connection,
 | 
					 | 
				
			||||||
    parse_extension,
 | 
					 | 
				
			||||||
    parse_subprotocol,
 | 
					 | 
				
			||||||
    parse_upgrade,
 | 
					 | 
				
			||||||
)
 | 
					 | 
				
			||||||
from .http11 import Request, Response
 | 
					 | 
				
			||||||
from .protocol import CONNECTING, OPEN, SERVER, Protocol, State
 | 
					 | 
				
			||||||
from .typing import (
 | 
					 | 
				
			||||||
    ConnectionOption,
 | 
					 | 
				
			||||||
    ExtensionHeader,
 | 
					 | 
				
			||||||
    LoggerLike,
 | 
					 | 
				
			||||||
    Origin,
 | 
					 | 
				
			||||||
    StatusLike,
 | 
					 | 
				
			||||||
    Subprotocol,
 | 
					 | 
				
			||||||
    UpgradeProtocol,
 | 
					 | 
				
			||||||
)
 | 
					 | 
				
			||||||
from .utils import accept_key
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
# See #940 for why lazy_import isn't used here for backwards compatibility.
 | 
					 | 
				
			||||||
# See #1400 for why listing compatibility imports in __all__ helps PyCharm.
 | 
					 | 
				
			||||||
from .legacy.server import *  # isort:skip  # noqa: I001
 | 
					 | 
				
			||||||
from .legacy.server import __all__ as legacy__all__
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
__all__ = ["ServerProtocol"] + legacy__all__
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class ServerProtocol(Protocol):
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Sans-I/O implementation of a WebSocket server connection.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Args:
 | 
					 | 
				
			||||||
        origins: Acceptable values of the ``Origin`` header; include
 | 
					 | 
				
			||||||
            :obj:`None` in the list if the lack of an origin is acceptable.
 | 
					 | 
				
			||||||
            This is useful for defending against Cross-Site WebSocket
 | 
					 | 
				
			||||||
            Hijacking attacks.
 | 
					 | 
				
			||||||
        extensions: List of supported extensions, in order in which they
 | 
					 | 
				
			||||||
            should be tried.
 | 
					 | 
				
			||||||
        subprotocols: List of supported subprotocols, in order of decreasing
 | 
					 | 
				
			||||||
            preference.
 | 
					 | 
				
			||||||
        select_subprotocol: Callback for selecting a subprotocol among
 | 
					 | 
				
			||||||
            those supported by the client and the server. It has the same
 | 
					 | 
				
			||||||
            signature as the :meth:`select_subprotocol` method, including a
 | 
					 | 
				
			||||||
            :class:`ServerProtocol` instance as first argument.
 | 
					 | 
				
			||||||
        state: Initial state of the WebSocket connection.
 | 
					 | 
				
			||||||
        max_size: Maximum size of incoming messages in bytes;
 | 
					 | 
				
			||||||
            :obj:`None` disables the limit.
 | 
					 | 
				
			||||||
        logger: Logger for this connection;
 | 
					 | 
				
			||||||
            defaults to ``logging.getLogger("websockets.server")``;
 | 
					 | 
				
			||||||
            see the :doc:`logging guide <../../topics/logging>` for details.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __init__(
 | 
					 | 
				
			||||||
        self,
 | 
					 | 
				
			||||||
        *,
 | 
					 | 
				
			||||||
        origins: Sequence[Origin | None] | None = None,
 | 
					 | 
				
			||||||
        extensions: Sequence[ServerExtensionFactory] | None = None,
 | 
					 | 
				
			||||||
        subprotocols: Sequence[Subprotocol] | None = None,
 | 
					 | 
				
			||||||
        select_subprotocol: (
 | 
					 | 
				
			||||||
            Callable[
 | 
					 | 
				
			||||||
                [ServerProtocol, Sequence[Subprotocol]],
 | 
					 | 
				
			||||||
                Subprotocol | None,
 | 
					 | 
				
			||||||
            ]
 | 
					 | 
				
			||||||
            | None
 | 
					 | 
				
			||||||
        ) = None,
 | 
					 | 
				
			||||||
        state: State = CONNECTING,
 | 
					 | 
				
			||||||
        max_size: int | None = 2**20,
 | 
					 | 
				
			||||||
        logger: LoggerLike | None = None,
 | 
					 | 
				
			||||||
    ) -> None:
 | 
					 | 
				
			||||||
        super().__init__(
 | 
					 | 
				
			||||||
            side=SERVER,
 | 
					 | 
				
			||||||
            state=state,
 | 
					 | 
				
			||||||
            max_size=max_size,
 | 
					 | 
				
			||||||
            logger=logger,
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        self.origins = origins
 | 
					 | 
				
			||||||
        self.available_extensions = extensions
 | 
					 | 
				
			||||||
        self.available_subprotocols = subprotocols
 | 
					 | 
				
			||||||
        if select_subprotocol is not None:
 | 
					 | 
				
			||||||
            # Bind select_subprotocol then shadow self.select_subprotocol.
 | 
					 | 
				
			||||||
            # Use setattr to work around https://github.com/python/mypy/issues/2427.
 | 
					 | 
				
			||||||
            setattr(
 | 
					 | 
				
			||||||
                self,
 | 
					 | 
				
			||||||
                "select_subprotocol",
 | 
					 | 
				
			||||||
                select_subprotocol.__get__(self, self.__class__),
 | 
					 | 
				
			||||||
            )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def accept(self, request: Request) -> Response:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Create a handshake response to accept the connection.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        If the handshake request is valid and the handshake successful,
 | 
					 | 
				
			||||||
        :meth:`accept` returns an HTTP response with status code 101.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Else, it returns an HTTP response with another status code. This rejects
 | 
					 | 
				
			||||||
        the connection, like :meth:`reject` would.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        You must send the handshake response with :meth:`send_response`.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        You may modify the response before sending it, typically by adding HTTP
 | 
					 | 
				
			||||||
        headers.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Args:
 | 
					 | 
				
			||||||
            request: WebSocket handshake request received from the client.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Returns:
 | 
					 | 
				
			||||||
            WebSocket handshake response or HTTP response to send to the client.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        try:
 | 
					 | 
				
			||||||
            (
 | 
					 | 
				
			||||||
                accept_header,
 | 
					 | 
				
			||||||
                extensions_header,
 | 
					 | 
				
			||||||
                protocol_header,
 | 
					 | 
				
			||||||
            ) = self.process_request(request)
 | 
					 | 
				
			||||||
        except InvalidOrigin as exc:
 | 
					 | 
				
			||||||
            request._exception = exc
 | 
					 | 
				
			||||||
            self.handshake_exc = exc
 | 
					 | 
				
			||||||
            if self.debug:
 | 
					 | 
				
			||||||
                self.logger.debug("! invalid origin", exc_info=True)
 | 
					 | 
				
			||||||
            return self.reject(
 | 
					 | 
				
			||||||
                http.HTTPStatus.FORBIDDEN,
 | 
					 | 
				
			||||||
                f"Failed to open a WebSocket connection: {exc}.\n",
 | 
					 | 
				
			||||||
            )
 | 
					 | 
				
			||||||
        except InvalidUpgrade as exc:
 | 
					 | 
				
			||||||
            request._exception = exc
 | 
					 | 
				
			||||||
            self.handshake_exc = exc
 | 
					 | 
				
			||||||
            if self.debug:
 | 
					 | 
				
			||||||
                self.logger.debug("! invalid upgrade", exc_info=True)
 | 
					 | 
				
			||||||
            response = self.reject(
 | 
					 | 
				
			||||||
                http.HTTPStatus.UPGRADE_REQUIRED,
 | 
					 | 
				
			||||||
                (
 | 
					 | 
				
			||||||
                    f"Failed to open a WebSocket connection: {exc}.\n"
 | 
					 | 
				
			||||||
                    f"\n"
 | 
					 | 
				
			||||||
                    f"You cannot access a WebSocket server directly "
 | 
					 | 
				
			||||||
                    f"with a browser. You need a WebSocket client.\n"
 | 
					 | 
				
			||||||
                ),
 | 
					 | 
				
			||||||
            )
 | 
					 | 
				
			||||||
            response.headers["Upgrade"] = "websocket"
 | 
					 | 
				
			||||||
            return response
 | 
					 | 
				
			||||||
        except InvalidHandshake as exc:
 | 
					 | 
				
			||||||
            request._exception = exc
 | 
					 | 
				
			||||||
            self.handshake_exc = exc
 | 
					 | 
				
			||||||
            if self.debug:
 | 
					 | 
				
			||||||
                self.logger.debug("! invalid handshake", exc_info=True)
 | 
					 | 
				
			||||||
            exc_chain = cast(BaseException, exc)
 | 
					 | 
				
			||||||
            exc_str = f"{exc_chain}"
 | 
					 | 
				
			||||||
            while exc_chain.__cause__ is not None:
 | 
					 | 
				
			||||||
                exc_chain = exc_chain.__cause__
 | 
					 | 
				
			||||||
                exc_str += f"; {exc_chain}"
 | 
					 | 
				
			||||||
            return self.reject(
 | 
					 | 
				
			||||||
                http.HTTPStatus.BAD_REQUEST,
 | 
					 | 
				
			||||||
                f"Failed to open a WebSocket connection: {exc_str}.\n",
 | 
					 | 
				
			||||||
            )
 | 
					 | 
				
			||||||
        except Exception as exc:
 | 
					 | 
				
			||||||
            # Handle exceptions raised by user-provided select_subprotocol and
 | 
					 | 
				
			||||||
            # unexpected errors.
 | 
					 | 
				
			||||||
            request._exception = exc
 | 
					 | 
				
			||||||
            self.handshake_exc = exc
 | 
					 | 
				
			||||||
            self.logger.error("opening handshake failed", exc_info=True)
 | 
					 | 
				
			||||||
            return self.reject(
 | 
					 | 
				
			||||||
                http.HTTPStatus.INTERNAL_SERVER_ERROR,
 | 
					 | 
				
			||||||
                (
 | 
					 | 
				
			||||||
                    "Failed to open a WebSocket connection.\n"
 | 
					 | 
				
			||||||
                    "See server log for more information.\n"
 | 
					 | 
				
			||||||
                ),
 | 
					 | 
				
			||||||
            )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        headers = Headers()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        headers["Date"] = email.utils.formatdate(usegmt=True)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        headers["Upgrade"] = "websocket"
 | 
					 | 
				
			||||||
        headers["Connection"] = "Upgrade"
 | 
					 | 
				
			||||||
        headers["Sec-WebSocket-Accept"] = accept_header
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if extensions_header is not None:
 | 
					 | 
				
			||||||
            headers["Sec-WebSocket-Extensions"] = extensions_header
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if protocol_header is not None:
 | 
					 | 
				
			||||||
            headers["Sec-WebSocket-Protocol"] = protocol_header
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        return Response(101, "Switching Protocols", headers)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def process_request(
 | 
					 | 
				
			||||||
        self,
 | 
					 | 
				
			||||||
        request: Request,
 | 
					 | 
				
			||||||
    ) -> tuple[str, str | None, str | None]:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Check a handshake request and negotiate extensions and subprotocol.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        This function doesn't verify that the request is an HTTP/1.1 or higher
 | 
					 | 
				
			||||||
        GET request and doesn't check the ``Host`` header. These controls are
 | 
					 | 
				
			||||||
        usually performed earlier in the HTTP request handling code. They're
 | 
					 | 
				
			||||||
        the responsibility of the caller.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Args:
 | 
					 | 
				
			||||||
            request: WebSocket handshake request received from the client.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Returns:
 | 
					 | 
				
			||||||
            ``Sec-WebSocket-Accept``, ``Sec-WebSocket-Extensions``, and
 | 
					 | 
				
			||||||
            ``Sec-WebSocket-Protocol`` headers for the handshake response.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Raises:
 | 
					 | 
				
			||||||
            InvalidHandshake: If the handshake request is invalid;
 | 
					 | 
				
			||||||
                then the server must return 400 Bad Request error.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        headers = request.headers
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        connection: list[ConnectionOption] = sum(
 | 
					 | 
				
			||||||
            [parse_connection(value) for value in headers.get_all("Connection")], []
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if not any(value.lower() == "upgrade" for value in connection):
 | 
					 | 
				
			||||||
            raise InvalidUpgrade(
 | 
					 | 
				
			||||||
                "Connection", ", ".join(connection) if connection else None
 | 
					 | 
				
			||||||
            )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        upgrade: list[UpgradeProtocol] = sum(
 | 
					 | 
				
			||||||
            [parse_upgrade(value) for value in headers.get_all("Upgrade")], []
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # For compatibility with non-strict implementations, ignore case when
 | 
					 | 
				
			||||||
        # checking the Upgrade header. The RFC always uses "websocket", except
 | 
					 | 
				
			||||||
        # in section 11.2. (IANA registration) where it uses "WebSocket".
 | 
					 | 
				
			||||||
        if not (len(upgrade) == 1 and upgrade[0].lower() == "websocket"):
 | 
					 | 
				
			||||||
            raise InvalidUpgrade("Upgrade", ", ".join(upgrade) if upgrade else None)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        try:
 | 
					 | 
				
			||||||
            key = headers["Sec-WebSocket-Key"]
 | 
					 | 
				
			||||||
        except KeyError as exc:
 | 
					 | 
				
			||||||
            raise InvalidHeader("Sec-WebSocket-Key") from exc
 | 
					 | 
				
			||||||
        except MultipleValuesError as exc:
 | 
					 | 
				
			||||||
            raise InvalidHeader("Sec-WebSocket-Key", "multiple values") from exc
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        try:
 | 
					 | 
				
			||||||
            raw_key = base64.b64decode(key.encode(), validate=True)
 | 
					 | 
				
			||||||
        except binascii.Error as exc:
 | 
					 | 
				
			||||||
            raise InvalidHeaderValue("Sec-WebSocket-Key", key) from exc
 | 
					 | 
				
			||||||
        if len(raw_key) != 16:
 | 
					 | 
				
			||||||
            raise InvalidHeaderValue("Sec-WebSocket-Key", key)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        try:
 | 
					 | 
				
			||||||
            version = headers["Sec-WebSocket-Version"]
 | 
					 | 
				
			||||||
        except KeyError as exc:
 | 
					 | 
				
			||||||
            raise InvalidHeader("Sec-WebSocket-Version") from exc
 | 
					 | 
				
			||||||
        except MultipleValuesError as exc:
 | 
					 | 
				
			||||||
            raise InvalidHeader("Sec-WebSocket-Version", "multiple values") from exc
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if version != "13":
 | 
					 | 
				
			||||||
            raise InvalidHeaderValue("Sec-WebSocket-Version", version)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        accept_header = accept_key(key)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        self.origin = self.process_origin(headers)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        extensions_header, self.extensions = self.process_extensions(headers)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        protocol_header = self.subprotocol = self.process_subprotocol(headers)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        return (
 | 
					 | 
				
			||||||
            accept_header,
 | 
					 | 
				
			||||||
            extensions_header,
 | 
					 | 
				
			||||||
            protocol_header,
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def process_origin(self, headers: Headers) -> Origin | None:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Handle the Origin HTTP request header.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Args:
 | 
					 | 
				
			||||||
            headers: WebSocket handshake request headers.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Returns:
 | 
					 | 
				
			||||||
           origin, if it is acceptable.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Raises:
 | 
					 | 
				
			||||||
            InvalidHandshake: If the Origin header is invalid.
 | 
					 | 
				
			||||||
            InvalidOrigin: If the origin isn't acceptable.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        # "The user agent MUST NOT include more than one Origin header field"
 | 
					 | 
				
			||||||
        # per https://datatracker.ietf.org/doc/html/rfc6454#section-7.3.
 | 
					 | 
				
			||||||
        try:
 | 
					 | 
				
			||||||
            origin = headers.get("Origin")
 | 
					 | 
				
			||||||
        except MultipleValuesError as exc:
 | 
					 | 
				
			||||||
            raise InvalidHeader("Origin", "multiple values") from exc
 | 
					 | 
				
			||||||
        if origin is not None:
 | 
					 | 
				
			||||||
            origin = cast(Origin, origin)
 | 
					 | 
				
			||||||
        if self.origins is not None:
 | 
					 | 
				
			||||||
            if origin not in self.origins:
 | 
					 | 
				
			||||||
                raise InvalidOrigin(origin)
 | 
					 | 
				
			||||||
        return origin
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def process_extensions(
 | 
					 | 
				
			||||||
        self,
 | 
					 | 
				
			||||||
        headers: Headers,
 | 
					 | 
				
			||||||
    ) -> tuple[str | None, list[Extension]]:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Handle the Sec-WebSocket-Extensions HTTP request header.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Accept or reject each extension proposed in the client request.
 | 
					 | 
				
			||||||
        Negotiate parameters for accepted extensions.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Per :rfc:`6455`, negotiation rules are defined by the specification of
 | 
					 | 
				
			||||||
        each extension.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        To provide this level of flexibility, for each extension proposed by
 | 
					 | 
				
			||||||
        the client, we check for a match with each extension available in the
 | 
					 | 
				
			||||||
        server configuration. If no match is found, the extension is ignored.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        If several variants of the same extension are proposed by the client,
 | 
					 | 
				
			||||||
        it may be accepted several times, which won't make sense in general.
 | 
					 | 
				
			||||||
        Extensions must implement their own requirements. For this purpose,
 | 
					 | 
				
			||||||
        the list of previously accepted extensions is provided.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        This process doesn't allow the server to reorder extensions. It can
 | 
					 | 
				
			||||||
        only select a subset of the extensions proposed by the client.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Other requirements, for example related to mandatory extensions or the
 | 
					 | 
				
			||||||
        order of extensions, may be implemented by overriding this method.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Args:
 | 
					 | 
				
			||||||
            headers: WebSocket handshake request headers.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Returns:
 | 
					 | 
				
			||||||
            ``Sec-WebSocket-Extensions`` HTTP response header and list of
 | 
					 | 
				
			||||||
            accepted extensions.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Raises:
 | 
					 | 
				
			||||||
            InvalidHandshake: If the Sec-WebSocket-Extensions header is invalid.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        response_header_value: str | None = None
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        extension_headers: list[ExtensionHeader] = []
 | 
					 | 
				
			||||||
        accepted_extensions: list[Extension] = []
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        header_values = headers.get_all("Sec-WebSocket-Extensions")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if header_values and self.available_extensions:
 | 
					 | 
				
			||||||
            parsed_header_values: list[ExtensionHeader] = sum(
 | 
					 | 
				
			||||||
                [parse_extension(header_value) for header_value in header_values], []
 | 
					 | 
				
			||||||
            )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            for name, request_params in parsed_header_values:
 | 
					 | 
				
			||||||
                for ext_factory in self.available_extensions:
 | 
					 | 
				
			||||||
                    # Skip non-matching extensions based on their name.
 | 
					 | 
				
			||||||
                    if ext_factory.name != name:
 | 
					 | 
				
			||||||
                        continue
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                    # Skip non-matching extensions based on their params.
 | 
					 | 
				
			||||||
                    try:
 | 
					 | 
				
			||||||
                        response_params, extension = ext_factory.process_request_params(
 | 
					 | 
				
			||||||
                            request_params, accepted_extensions
 | 
					 | 
				
			||||||
                        )
 | 
					 | 
				
			||||||
                    except NegotiationError:
 | 
					 | 
				
			||||||
                        continue
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                    # Add matching extension to the final list.
 | 
					 | 
				
			||||||
                    extension_headers.append((name, response_params))
 | 
					 | 
				
			||||||
                    accepted_extensions.append(extension)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                    # Break out of the loop once we have a match.
 | 
					 | 
				
			||||||
                    break
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                # If we didn't break from the loop, no extension in our list
 | 
					 | 
				
			||||||
                # matched what the client sent. The extension is declined.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # Serialize extension header.
 | 
					 | 
				
			||||||
        if extension_headers:
 | 
					 | 
				
			||||||
            response_header_value = build_extension(extension_headers)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        return response_header_value, accepted_extensions
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def process_subprotocol(self, headers: Headers) -> Subprotocol | None:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Handle the Sec-WebSocket-Protocol HTTP request header.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Args:
 | 
					 | 
				
			||||||
            headers: WebSocket handshake request headers.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Returns:
 | 
					 | 
				
			||||||
           Subprotocol, if one was selected; this is also the value of the
 | 
					 | 
				
			||||||
           ``Sec-WebSocket-Protocol`` response header.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Raises:
 | 
					 | 
				
			||||||
            InvalidHandshake: If the Sec-WebSocket-Subprotocol header is invalid.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        subprotocols: Sequence[Subprotocol] = sum(
 | 
					 | 
				
			||||||
            [
 | 
					 | 
				
			||||||
                parse_subprotocol(header_value)
 | 
					 | 
				
			||||||
                for header_value in headers.get_all("Sec-WebSocket-Protocol")
 | 
					 | 
				
			||||||
            ],
 | 
					 | 
				
			||||||
            [],
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        return self.select_subprotocol(subprotocols)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def select_subprotocol(
 | 
					 | 
				
			||||||
        self,
 | 
					 | 
				
			||||||
        subprotocols: Sequence[Subprotocol],
 | 
					 | 
				
			||||||
    ) -> Subprotocol | None:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Pick a subprotocol among those offered by the client.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        If several subprotocols are supported by both the client and the server,
 | 
					 | 
				
			||||||
        pick the first one in the list declared the server.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        If the server doesn't support any subprotocols, continue without a
 | 
					 | 
				
			||||||
        subprotocol, regardless of what the client offers.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        If the server supports at least one subprotocol and the client doesn't
 | 
					 | 
				
			||||||
        offer any, abort the handshake with an HTTP 400 error.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        You provide a ``select_subprotocol`` argument to :class:`ServerProtocol`
 | 
					 | 
				
			||||||
        to override this logic. For example, you could accept the connection
 | 
					 | 
				
			||||||
        even if client doesn't offer a subprotocol, rather than reject it.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Here's how to negotiate the ``chat`` subprotocol if the client supports
 | 
					 | 
				
			||||||
        it and continue without a subprotocol otherwise::
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            def select_subprotocol(protocol, subprotocols):
 | 
					 | 
				
			||||||
                if "chat" in subprotocols:
 | 
					 | 
				
			||||||
                    return "chat"
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Args:
 | 
					 | 
				
			||||||
            subprotocols: List of subprotocols offered by the client.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Returns:
 | 
					 | 
				
			||||||
            Selected subprotocol, if a common subprotocol was found.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            :obj:`None` to continue without a subprotocol.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Raises:
 | 
					 | 
				
			||||||
            NegotiationError: Custom implementations may raise this exception
 | 
					 | 
				
			||||||
                to abort the handshake with an HTTP 400 error.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        # Server doesn't offer any subprotocols.
 | 
					 | 
				
			||||||
        if not self.available_subprotocols:  # None or empty list
 | 
					 | 
				
			||||||
            return None
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # Server offers at least one subprotocol but client doesn't offer any.
 | 
					 | 
				
			||||||
        if not subprotocols:
 | 
					 | 
				
			||||||
            raise NegotiationError("missing subprotocol")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # Server and client both offer subprotocols. Look for a shared one.
 | 
					 | 
				
			||||||
        proposed_subprotocols = set(subprotocols)
 | 
					 | 
				
			||||||
        for subprotocol in self.available_subprotocols:
 | 
					 | 
				
			||||||
            if subprotocol in proposed_subprotocols:
 | 
					 | 
				
			||||||
                return subprotocol
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # No common subprotocol was found.
 | 
					 | 
				
			||||||
        raise NegotiationError(
 | 
					 | 
				
			||||||
            "invalid subprotocol; expected one of "
 | 
					 | 
				
			||||||
            + ", ".join(self.available_subprotocols)
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def reject(self, status: StatusLike, text: str) -> Response:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Create a handshake response to reject the connection.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        A short plain text response is the best fallback when failing to
 | 
					 | 
				
			||||||
        establish a WebSocket connection.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        You must send the handshake response with :meth:`send_response`.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        You may modify the response before sending it, for example by changing
 | 
					 | 
				
			||||||
        HTTP headers.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Args:
 | 
					 | 
				
			||||||
            status: HTTP status code.
 | 
					 | 
				
			||||||
            text: HTTP response body; it will be encoded to UTF-8.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Returns:
 | 
					 | 
				
			||||||
            HTTP response to send to the client.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        # If a user passes an int instead of a HTTPStatus, fix it automatically.
 | 
					 | 
				
			||||||
        status = http.HTTPStatus(status)
 | 
					 | 
				
			||||||
        body = text.encode()
 | 
					 | 
				
			||||||
        headers = Headers(
 | 
					 | 
				
			||||||
            [
 | 
					 | 
				
			||||||
                ("Date", email.utils.formatdate(usegmt=True)),
 | 
					 | 
				
			||||||
                ("Connection", "close"),
 | 
					 | 
				
			||||||
                ("Content-Length", str(len(body))),
 | 
					 | 
				
			||||||
                ("Content-Type", "text/plain; charset=utf-8"),
 | 
					 | 
				
			||||||
            ]
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        return Response(status.value, status.phrase, headers, body)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def send_response(self, response: Response) -> None:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Send a handshake response to the client.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Args:
 | 
					 | 
				
			||||||
            response: WebSocket handshake response event to send.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        if self.debug:
 | 
					 | 
				
			||||||
            code, phrase = response.status_code, response.reason_phrase
 | 
					 | 
				
			||||||
            self.logger.debug("> HTTP/1.1 %d %s", code, phrase)
 | 
					 | 
				
			||||||
            for key, value in response.headers.raw_items():
 | 
					 | 
				
			||||||
                self.logger.debug("> %s: %s", key, value)
 | 
					 | 
				
			||||||
            if response.body is not None:
 | 
					 | 
				
			||||||
                self.logger.debug("> [body] (%d bytes)", len(response.body))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        self.writes.append(response.serialize())
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if response.status_code == 101:
 | 
					 | 
				
			||||||
            assert self.state is CONNECTING
 | 
					 | 
				
			||||||
            self.state = OPEN
 | 
					 | 
				
			||||||
            self.logger.info("connection open")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            # handshake_exc may be already set if accept() encountered an error.
 | 
					 | 
				
			||||||
            # If the connection isn't open, set handshake_exc to guarantee that
 | 
					 | 
				
			||||||
            # handshake_exc is None if and only if opening handshake succeeded.
 | 
					 | 
				
			||||||
            if self.handshake_exc is None:
 | 
					 | 
				
			||||||
                self.handshake_exc = InvalidStatus(response)
 | 
					 | 
				
			||||||
            self.logger.info(
 | 
					 | 
				
			||||||
                "connection rejected (%d %s)",
 | 
					 | 
				
			||||||
                response.status_code,
 | 
					 | 
				
			||||||
                response.reason_phrase,
 | 
					 | 
				
			||||||
            )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            self.send_eof()
 | 
					 | 
				
			||||||
            self.parser = self.discard()
 | 
					 | 
				
			||||||
            next(self.parser)  # start coroutine
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def parse(self) -> Generator[None, None, None]:
 | 
					 | 
				
			||||||
        if self.state is CONNECTING:
 | 
					 | 
				
			||||||
            try:
 | 
					 | 
				
			||||||
                request = yield from Request.parse(
 | 
					 | 
				
			||||||
                    self.reader.read_line,
 | 
					 | 
				
			||||||
                )
 | 
					 | 
				
			||||||
            except Exception as exc:
 | 
					 | 
				
			||||||
                self.handshake_exc = exc
 | 
					 | 
				
			||||||
                self.send_eof()
 | 
					 | 
				
			||||||
                self.parser = self.discard()
 | 
					 | 
				
			||||||
                next(self.parser)  # start coroutine
 | 
					 | 
				
			||||||
                yield
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            if self.debug:
 | 
					 | 
				
			||||||
                self.logger.debug("< GET %s HTTP/1.1", request.path)
 | 
					 | 
				
			||||||
                for key, value in request.headers.raw_items():
 | 
					 | 
				
			||||||
                    self.logger.debug("< %s: %s", key, value)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            self.events.append(request)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        yield from super().parse()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class ServerConnection(ServerProtocol):
 | 
					 | 
				
			||||||
    def __init__(self, *args: Any, **kwargs: Any) -> None:
 | 
					 | 
				
			||||||
        warnings.warn(  # deprecated in 11.0 - 2023-04-02
 | 
					 | 
				
			||||||
            "ServerConnection was renamed to ServerProtocol",
 | 
					 | 
				
			||||||
            DeprecationWarning,
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        super().__init__(*args, **kwargs)
 | 
					 | 
				
			||||||
@@ -1,222 +0,0 @@
 | 
				
			|||||||
/* C implementation of performance sensitive functions. */
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
#define PY_SSIZE_T_CLEAN
 | 
					 | 
				
			||||||
#include <Python.h>
 | 
					 | 
				
			||||||
#include <stdint.h> /* uint8_t, uint32_t, uint64_t */
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
#if __ARM_NEON
 | 
					 | 
				
			||||||
#include <arm_neon.h>
 | 
					 | 
				
			||||||
#elif __SSE2__
 | 
					 | 
				
			||||||
#include <emmintrin.h>
 | 
					 | 
				
			||||||
#endif
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
static const Py_ssize_t MASK_LEN = 4;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
/* Similar to PyBytes_AsStringAndSize, but accepts more types */
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
static int
 | 
					 | 
				
			||||||
_PyBytesLike_AsStringAndSize(PyObject *obj, PyObject **tmp, char **buffer, Py_ssize_t *length)
 | 
					 | 
				
			||||||
{
 | 
					 | 
				
			||||||
    // This supports bytes, bytearrays, and memoryview objects,
 | 
					 | 
				
			||||||
    // which are common data structures for handling byte streams.
 | 
					 | 
				
			||||||
    // If *tmp isn't NULL, the caller gets a new reference.
 | 
					 | 
				
			||||||
    if (PyBytes_Check(obj))
 | 
					 | 
				
			||||||
    {
 | 
					 | 
				
			||||||
        *tmp = NULL;
 | 
					 | 
				
			||||||
        *buffer = PyBytes_AS_STRING(obj);
 | 
					 | 
				
			||||||
        *length = PyBytes_GET_SIZE(obj);
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
    else if (PyByteArray_Check(obj))
 | 
					 | 
				
			||||||
    {
 | 
					 | 
				
			||||||
        *tmp = NULL;
 | 
					 | 
				
			||||||
        *buffer = PyByteArray_AS_STRING(obj);
 | 
					 | 
				
			||||||
        *length = PyByteArray_GET_SIZE(obj);
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
    else if (PyMemoryView_Check(obj))
 | 
					 | 
				
			||||||
    {
 | 
					 | 
				
			||||||
        *tmp = PyMemoryView_GetContiguous(obj, PyBUF_READ, 'C');
 | 
					 | 
				
			||||||
        if (*tmp == NULL)
 | 
					 | 
				
			||||||
        {
 | 
					 | 
				
			||||||
            return -1;
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
        Py_buffer *mv_buf;
 | 
					 | 
				
			||||||
        mv_buf = PyMemoryView_GET_BUFFER(*tmp);
 | 
					 | 
				
			||||||
        *buffer = mv_buf->buf;
 | 
					 | 
				
			||||||
        *length = mv_buf->len;
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
    else
 | 
					 | 
				
			||||||
    {
 | 
					 | 
				
			||||||
        PyErr_Format(
 | 
					 | 
				
			||||||
            PyExc_TypeError,
 | 
					 | 
				
			||||||
            "expected a bytes-like object, %.200s found",
 | 
					 | 
				
			||||||
            Py_TYPE(obj)->tp_name);
 | 
					 | 
				
			||||||
        return -1;
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    return 0;
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
/* C implementation of websockets.utils.apply_mask */
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
static PyObject *
 | 
					 | 
				
			||||||
apply_mask(PyObject *self, PyObject *args, PyObject *kwds)
 | 
					 | 
				
			||||||
{
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    // In order to support various bytes-like types, accept any Python object.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    static char *kwlist[] = {"data", "mask", NULL};
 | 
					 | 
				
			||||||
    PyObject *input_obj;
 | 
					 | 
				
			||||||
    PyObject *mask_obj;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    // A pointer to a char * + length will be extracted from the data and mask
 | 
					 | 
				
			||||||
    // arguments, possibly via a Py_buffer.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    PyObject *input_tmp = NULL;
 | 
					 | 
				
			||||||
    char *input;
 | 
					 | 
				
			||||||
    Py_ssize_t input_len;
 | 
					 | 
				
			||||||
    PyObject *mask_tmp = NULL;
 | 
					 | 
				
			||||||
    char *mask;
 | 
					 | 
				
			||||||
    Py_ssize_t mask_len;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    // Initialize a PyBytesObject then get a pointer to the underlying char *
 | 
					 | 
				
			||||||
    // in order to avoid an extra memory copy in PyBytes_FromStringAndSize.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    PyObject *result = NULL;
 | 
					 | 
				
			||||||
    char *output;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    // Other variables.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Py_ssize_t i = 0;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    // Parse inputs.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    if (!PyArg_ParseTupleAndKeywords(
 | 
					 | 
				
			||||||
            args, kwds, "OO", kwlist, &input_obj, &mask_obj))
 | 
					 | 
				
			||||||
    {
 | 
					 | 
				
			||||||
        goto exit;
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    if (_PyBytesLike_AsStringAndSize(input_obj, &input_tmp, &input, &input_len) == -1)
 | 
					 | 
				
			||||||
    {
 | 
					 | 
				
			||||||
        goto exit;
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    if (_PyBytesLike_AsStringAndSize(mask_obj, &mask_tmp, &mask, &mask_len) == -1)
 | 
					 | 
				
			||||||
    {
 | 
					 | 
				
			||||||
        goto exit;
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    if (mask_len != MASK_LEN)
 | 
					 | 
				
			||||||
    {
 | 
					 | 
				
			||||||
        PyErr_SetString(PyExc_ValueError, "mask must contain 4 bytes");
 | 
					 | 
				
			||||||
        goto exit;
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    // Create output.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    result = PyBytes_FromStringAndSize(NULL, input_len);
 | 
					 | 
				
			||||||
    if (result == NULL)
 | 
					 | 
				
			||||||
    {
 | 
					 | 
				
			||||||
        goto exit;
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    // Since we just created result, we don't need error checks.
 | 
					 | 
				
			||||||
    output = PyBytes_AS_STRING(result);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    // Perform the masking operation.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    // Apparently GCC cannot figure out the following optimizations by itself.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    // We need a new scope for MSVC 2010 (non C99 friendly)
 | 
					 | 
				
			||||||
    {
 | 
					 | 
				
			||||||
#if __ARM_NEON
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        // With NEON support, XOR by blocks of 16 bytes = 128 bits.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Py_ssize_t input_len_128 = input_len & ~15;
 | 
					 | 
				
			||||||
        uint8x16_t mask_128 = vreinterpretq_u8_u32(vdupq_n_u32(*(uint32_t *)mask));
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        for (; i < input_len_128; i += 16)
 | 
					 | 
				
			||||||
        {
 | 
					 | 
				
			||||||
            uint8x16_t in_128 = vld1q_u8((uint8_t *)(input + i));
 | 
					 | 
				
			||||||
            uint8x16_t out_128 = veorq_u8(in_128, mask_128);
 | 
					 | 
				
			||||||
            vst1q_u8((uint8_t *)(output + i), out_128);
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
#elif __SSE2__
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        // With SSE2 support, XOR by blocks of 16 bytes = 128 bits.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        // Since we cannot control the 16-bytes alignment of input and output
 | 
					 | 
				
			||||||
        // buffers, we rely on loadu/storeu rather than load/store.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Py_ssize_t input_len_128 = input_len & ~15;
 | 
					 | 
				
			||||||
        __m128i mask_128 = _mm_set1_epi32(*(uint32_t *)mask);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        for (; i < input_len_128; i += 16)
 | 
					 | 
				
			||||||
        {
 | 
					 | 
				
			||||||
            __m128i in_128 = _mm_loadu_si128((__m128i *)(input + i));
 | 
					 | 
				
			||||||
            __m128i out_128 = _mm_xor_si128(in_128, mask_128);
 | 
					 | 
				
			||||||
            _mm_storeu_si128((__m128i *)(output + i), out_128);
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
#else
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        // Without SSE2 support, XOR by blocks of 8 bytes = 64 bits.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        // We assume the memory allocator aligns everything on 8 bytes boundaries.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Py_ssize_t input_len_64 = input_len & ~7;
 | 
					 | 
				
			||||||
        uint32_t mask_32 = *(uint32_t *)mask;
 | 
					 | 
				
			||||||
        uint64_t mask_64 = ((uint64_t)mask_32 << 32) | (uint64_t)mask_32;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        for (; i < input_len_64; i += 8)
 | 
					 | 
				
			||||||
        {
 | 
					 | 
				
			||||||
            *(uint64_t *)(output + i) = *(uint64_t *)(input + i) ^ mask_64;
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
#endif
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    // XOR the remainder of the input byte by byte.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    for (; i < input_len; i++)
 | 
					 | 
				
			||||||
    {
 | 
					 | 
				
			||||||
        output[i] = input[i] ^ mask[i & (MASK_LEN - 1)];
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
exit:
 | 
					 | 
				
			||||||
    Py_XDECREF(input_tmp);
 | 
					 | 
				
			||||||
    Py_XDECREF(mask_tmp);
 | 
					 | 
				
			||||||
    return result;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
static PyMethodDef speedups_methods[] = {
 | 
					 | 
				
			||||||
    {
 | 
					 | 
				
			||||||
        "apply_mask",
 | 
					 | 
				
			||||||
        (PyCFunction)apply_mask,
 | 
					 | 
				
			||||||
        METH_VARARGS | METH_KEYWORDS,
 | 
					 | 
				
			||||||
        "Apply masking to the data of a WebSocket message.",
 | 
					 | 
				
			||||||
    },
 | 
					 | 
				
			||||||
    {NULL, NULL, 0, NULL},      /* Sentinel */
 | 
					 | 
				
			||||||
};
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
static struct PyModuleDef speedups_module = {
 | 
					 | 
				
			||||||
    PyModuleDef_HEAD_INIT,
 | 
					 | 
				
			||||||
    "websocket.speedups",       /* m_name */
 | 
					 | 
				
			||||||
    "C implementation of performance sensitive functions.",
 | 
					 | 
				
			||||||
                                /* m_doc */
 | 
					 | 
				
			||||||
    -1,                         /* m_size */
 | 
					 | 
				
			||||||
    speedups_methods,           /* m_methods */
 | 
					 | 
				
			||||||
    NULL,
 | 
					 | 
				
			||||||
    NULL,
 | 
					 | 
				
			||||||
    NULL,
 | 
					 | 
				
			||||||
    NULL
 | 
					 | 
				
			||||||
};
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
PyMODINIT_FUNC
 | 
					 | 
				
			||||||
PyInit_speedups(void)
 | 
					 | 
				
			||||||
{
 | 
					 | 
				
			||||||
    return PyModule_Create(&speedups_module);
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
@@ -1 +0,0 @@
 | 
				
			|||||||
def apply_mask(data: bytes, mask: bytes) -> bytes: ...
 | 
					 | 
				
			||||||
@@ -1,151 +0,0 @@
 | 
				
			|||||||
from __future__ import annotations
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
from typing import Generator
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class StreamReader:
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Generator-based stream reader.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    This class doesn't support concurrent calls to :meth:`read_line`,
 | 
					 | 
				
			||||||
    :meth:`read_exact`, or :meth:`read_to_eof`. Make sure calls are
 | 
					 | 
				
			||||||
    serialized.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __init__(self) -> None:
 | 
					 | 
				
			||||||
        self.buffer = bytearray()
 | 
					 | 
				
			||||||
        self.eof = False
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def read_line(self, m: int) -> Generator[None, None, bytes]:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Read a LF-terminated line from the stream.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        This is a generator-based coroutine.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        The return value includes the LF character.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Args:
 | 
					 | 
				
			||||||
            m: Maximum number bytes to read; this is a security limit.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Raises:
 | 
					 | 
				
			||||||
            EOFError: If the stream ends without a LF.
 | 
					 | 
				
			||||||
            RuntimeError: If the stream ends in more than ``m`` bytes.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        n = 0  # number of bytes to read
 | 
					 | 
				
			||||||
        p = 0  # number of bytes without a newline
 | 
					 | 
				
			||||||
        while True:
 | 
					 | 
				
			||||||
            n = self.buffer.find(b"\n", p) + 1
 | 
					 | 
				
			||||||
            if n > 0:
 | 
					 | 
				
			||||||
                break
 | 
					 | 
				
			||||||
            p = len(self.buffer)
 | 
					 | 
				
			||||||
            if p > m:
 | 
					 | 
				
			||||||
                raise RuntimeError(f"read {p} bytes, expected no more than {m} bytes")
 | 
					 | 
				
			||||||
            if self.eof:
 | 
					 | 
				
			||||||
                raise EOFError(f"stream ends after {p} bytes, before end of line")
 | 
					 | 
				
			||||||
            yield
 | 
					 | 
				
			||||||
        if n > m:
 | 
					 | 
				
			||||||
            raise RuntimeError(f"read {n} bytes, expected no more than {m} bytes")
 | 
					 | 
				
			||||||
        r = self.buffer[:n]
 | 
					 | 
				
			||||||
        del self.buffer[:n]
 | 
					 | 
				
			||||||
        return r
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def read_exact(self, n: int) -> Generator[None, None, bytes]:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Read a given number of bytes from the stream.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        This is a generator-based coroutine.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Args:
 | 
					 | 
				
			||||||
            n: How many bytes to read.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Raises:
 | 
					 | 
				
			||||||
            EOFError: If the stream ends in less than ``n`` bytes.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        assert n >= 0
 | 
					 | 
				
			||||||
        while len(self.buffer) < n:
 | 
					 | 
				
			||||||
            if self.eof:
 | 
					 | 
				
			||||||
                p = len(self.buffer)
 | 
					 | 
				
			||||||
                raise EOFError(f"stream ends after {p} bytes, expected {n} bytes")
 | 
					 | 
				
			||||||
            yield
 | 
					 | 
				
			||||||
        r = self.buffer[:n]
 | 
					 | 
				
			||||||
        del self.buffer[:n]
 | 
					 | 
				
			||||||
        return r
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def read_to_eof(self, m: int) -> Generator[None, None, bytes]:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Read all bytes from the stream.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        This is a generator-based coroutine.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Args:
 | 
					 | 
				
			||||||
            m: Maximum number bytes to read; this is a security limit.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Raises:
 | 
					 | 
				
			||||||
            RuntimeError: If the stream ends in more than ``m`` bytes.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        while not self.eof:
 | 
					 | 
				
			||||||
            p = len(self.buffer)
 | 
					 | 
				
			||||||
            if p > m:
 | 
					 | 
				
			||||||
                raise RuntimeError(f"read {p} bytes, expected no more than {m} bytes")
 | 
					 | 
				
			||||||
            yield
 | 
					 | 
				
			||||||
        r = self.buffer[:]
 | 
					 | 
				
			||||||
        del self.buffer[:]
 | 
					 | 
				
			||||||
        return r
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def at_eof(self) -> Generator[None, None, bool]:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Tell whether the stream has ended and all data was read.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        This is a generator-based coroutine.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        while True:
 | 
					 | 
				
			||||||
            if self.buffer:
 | 
					 | 
				
			||||||
                return False
 | 
					 | 
				
			||||||
            if self.eof:
 | 
					 | 
				
			||||||
                return True
 | 
					 | 
				
			||||||
            # When all data was read but the stream hasn't ended, we can't
 | 
					 | 
				
			||||||
            # tell if until either feed_data() or feed_eof() is called.
 | 
					 | 
				
			||||||
            yield
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def feed_data(self, data: bytes) -> None:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Write data to the stream.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        :meth:`feed_data` cannot be called after :meth:`feed_eof`.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Args:
 | 
					 | 
				
			||||||
            data: Data to write.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Raises:
 | 
					 | 
				
			||||||
            EOFError: If the stream has ended.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        if self.eof:
 | 
					 | 
				
			||||||
            raise EOFError("stream ended")
 | 
					 | 
				
			||||||
        self.buffer += data
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def feed_eof(self) -> None:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        End the stream.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        :meth:`feed_eof` cannot be called more than once.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Raises:
 | 
					 | 
				
			||||||
            EOFError: If the stream has ended.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        if self.eof:
 | 
					 | 
				
			||||||
            raise EOFError("stream ended")
 | 
					 | 
				
			||||||
        self.eof = True
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def discard(self) -> None:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Discard all buffered data, but don't end the stream.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        del self.buffer[:]
 | 
					 | 
				
			||||||
@@ -1,336 +0,0 @@
 | 
				
			|||||||
from __future__ import annotations
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
import socket
 | 
					 | 
				
			||||||
import ssl as ssl_module
 | 
					 | 
				
			||||||
import threading
 | 
					 | 
				
			||||||
import warnings
 | 
					 | 
				
			||||||
from typing import Any, Sequence
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
from ..client import ClientProtocol
 | 
					 | 
				
			||||||
from ..datastructures import HeadersLike
 | 
					 | 
				
			||||||
from ..extensions.base import ClientExtensionFactory
 | 
					 | 
				
			||||||
from ..extensions.permessage_deflate import enable_client_permessage_deflate
 | 
					 | 
				
			||||||
from ..headers import validate_subprotocols
 | 
					 | 
				
			||||||
from ..http11 import USER_AGENT, Response
 | 
					 | 
				
			||||||
from ..protocol import CONNECTING, Event
 | 
					 | 
				
			||||||
from ..typing import LoggerLike, Origin, Subprotocol
 | 
					 | 
				
			||||||
from ..uri import parse_uri
 | 
					 | 
				
			||||||
from .connection import Connection
 | 
					 | 
				
			||||||
from .utils import Deadline
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
__all__ = ["connect", "unix_connect", "ClientConnection"]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class ClientConnection(Connection):
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    :mod:`threading` implementation of a WebSocket client connection.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    :class:`ClientConnection` provides :meth:`recv` and :meth:`send` methods for
 | 
					 | 
				
			||||||
    receiving and sending messages.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    It supports iteration to receive messages::
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        for message in websocket:
 | 
					 | 
				
			||||||
            process(message)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    The iterator exits normally when the connection is closed with close code
 | 
					 | 
				
			||||||
    1000 (OK) or 1001 (going away) or without a close code. It raises a
 | 
					 | 
				
			||||||
    :exc:`~websockets.exceptions.ConnectionClosedError` when the connection is
 | 
					 | 
				
			||||||
    closed with any other code.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Args:
 | 
					 | 
				
			||||||
        socket: Socket connected to a WebSocket server.
 | 
					 | 
				
			||||||
        protocol: Sans-I/O connection.
 | 
					 | 
				
			||||||
        close_timeout: Timeout for closing the connection in seconds.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __init__(
 | 
					 | 
				
			||||||
        self,
 | 
					 | 
				
			||||||
        socket: socket.socket,
 | 
					 | 
				
			||||||
        protocol: ClientProtocol,
 | 
					 | 
				
			||||||
        *,
 | 
					 | 
				
			||||||
        close_timeout: float | None = 10,
 | 
					 | 
				
			||||||
    ) -> None:
 | 
					 | 
				
			||||||
        self.protocol: ClientProtocol
 | 
					 | 
				
			||||||
        self.response_rcvd = threading.Event()
 | 
					 | 
				
			||||||
        super().__init__(
 | 
					 | 
				
			||||||
            socket,
 | 
					 | 
				
			||||||
            protocol,
 | 
					 | 
				
			||||||
            close_timeout=close_timeout,
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def handshake(
 | 
					 | 
				
			||||||
        self,
 | 
					 | 
				
			||||||
        additional_headers: HeadersLike | None = None,
 | 
					 | 
				
			||||||
        user_agent_header: str | None = USER_AGENT,
 | 
					 | 
				
			||||||
        timeout: float | None = None,
 | 
					 | 
				
			||||||
    ) -> None:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Perform the opening handshake.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        with self.send_context(expected_state=CONNECTING):
 | 
					 | 
				
			||||||
            self.request = self.protocol.connect()
 | 
					 | 
				
			||||||
            if additional_headers is not None:
 | 
					 | 
				
			||||||
                self.request.headers.update(additional_headers)
 | 
					 | 
				
			||||||
            if user_agent_header is not None:
 | 
					 | 
				
			||||||
                self.request.headers["User-Agent"] = user_agent_header
 | 
					 | 
				
			||||||
            self.protocol.send_request(self.request)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if not self.response_rcvd.wait(timeout):
 | 
					 | 
				
			||||||
            raise TimeoutError("timed out during handshake")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # self.protocol.handshake_exc is always set when the connection is lost
 | 
					 | 
				
			||||||
        # before receiving a response, when the response cannot be parsed, or
 | 
					 | 
				
			||||||
        # when the response fails the handshake.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if self.protocol.handshake_exc is not None:
 | 
					 | 
				
			||||||
            raise self.protocol.handshake_exc
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def process_event(self, event: Event) -> None:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Process one incoming event.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        # First event - handshake response.
 | 
					 | 
				
			||||||
        if self.response is None:
 | 
					 | 
				
			||||||
            assert isinstance(event, Response)
 | 
					 | 
				
			||||||
            self.response = event
 | 
					 | 
				
			||||||
            self.response_rcvd.set()
 | 
					 | 
				
			||||||
        # Later events - frames.
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            super().process_event(event)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def recv_events(self) -> None:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Read incoming data from the socket and process events.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        try:
 | 
					 | 
				
			||||||
            super().recv_events()
 | 
					 | 
				
			||||||
        finally:
 | 
					 | 
				
			||||||
            # If the connection is closed during the handshake, unblock it.
 | 
					 | 
				
			||||||
            self.response_rcvd.set()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def connect(
 | 
					 | 
				
			||||||
    uri: str,
 | 
					 | 
				
			||||||
    *,
 | 
					 | 
				
			||||||
    # TCP/TLS
 | 
					 | 
				
			||||||
    sock: socket.socket | None = None,
 | 
					 | 
				
			||||||
    ssl: ssl_module.SSLContext | None = None,
 | 
					 | 
				
			||||||
    server_hostname: str | None = None,
 | 
					 | 
				
			||||||
    # WebSocket
 | 
					 | 
				
			||||||
    origin: Origin | None = None,
 | 
					 | 
				
			||||||
    extensions: Sequence[ClientExtensionFactory] | None = None,
 | 
					 | 
				
			||||||
    subprotocols: Sequence[Subprotocol] | None = None,
 | 
					 | 
				
			||||||
    additional_headers: HeadersLike | None = None,
 | 
					 | 
				
			||||||
    user_agent_header: str | None = USER_AGENT,
 | 
					 | 
				
			||||||
    compression: str | None = "deflate",
 | 
					 | 
				
			||||||
    # Timeouts
 | 
					 | 
				
			||||||
    open_timeout: float | None = 10,
 | 
					 | 
				
			||||||
    close_timeout: float | None = 10,
 | 
					 | 
				
			||||||
    # Limits
 | 
					 | 
				
			||||||
    max_size: int | None = 2**20,
 | 
					 | 
				
			||||||
    # Logging
 | 
					 | 
				
			||||||
    logger: LoggerLike | None = None,
 | 
					 | 
				
			||||||
    # Escape hatch for advanced customization
 | 
					 | 
				
			||||||
    create_connection: type[ClientConnection] | None = None,
 | 
					 | 
				
			||||||
    **kwargs: Any,
 | 
					 | 
				
			||||||
) -> ClientConnection:
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Connect to the WebSocket server at ``uri``.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    This function returns a :class:`ClientConnection` instance, which you can
 | 
					 | 
				
			||||||
    use to send and receive messages.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    :func:`connect` may be used as a context manager::
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        from websockets.sync.client import connect
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        with connect(...) as websocket:
 | 
					 | 
				
			||||||
            ...
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    The connection is closed automatically when exiting the context.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Args:
 | 
					 | 
				
			||||||
        uri: URI of the WebSocket server.
 | 
					 | 
				
			||||||
        sock: Preexisting TCP socket. ``sock`` overrides the host and port
 | 
					 | 
				
			||||||
            from ``uri``. You may call :func:`socket.create_connection` to
 | 
					 | 
				
			||||||
            create a suitable TCP socket.
 | 
					 | 
				
			||||||
        ssl: Configuration for enabling TLS on the connection.
 | 
					 | 
				
			||||||
        server_hostname: Host name for the TLS handshake. ``server_hostname``
 | 
					 | 
				
			||||||
            overrides the host name from ``uri``.
 | 
					 | 
				
			||||||
        origin: Value of the ``Origin`` header, for servers that require it.
 | 
					 | 
				
			||||||
        extensions: List of supported extensions, in order in which they
 | 
					 | 
				
			||||||
            should be negotiated and run.
 | 
					 | 
				
			||||||
        subprotocols: List of supported subprotocols, in order of decreasing
 | 
					 | 
				
			||||||
            preference.
 | 
					 | 
				
			||||||
        additional_headers (HeadersLike | None): Arbitrary HTTP headers to add
 | 
					 | 
				
			||||||
            to the handshake request.
 | 
					 | 
				
			||||||
        user_agent_header: Value of  the ``User-Agent`` request header.
 | 
					 | 
				
			||||||
            It defaults to ``"Python/x.y.z websockets/X.Y"``.
 | 
					 | 
				
			||||||
            Setting it to :obj:`None` removes the header.
 | 
					 | 
				
			||||||
        compression: The "permessage-deflate" extension is enabled by default.
 | 
					 | 
				
			||||||
            Set ``compression`` to :obj:`None` to disable it. See the
 | 
					 | 
				
			||||||
            :doc:`compression guide <../../topics/compression>` for details.
 | 
					 | 
				
			||||||
        open_timeout: Timeout for opening the connection in seconds.
 | 
					 | 
				
			||||||
            :obj:`None` disables the timeout.
 | 
					 | 
				
			||||||
        close_timeout: Timeout for closing the connection in seconds.
 | 
					 | 
				
			||||||
            :obj:`None` disables the timeout.
 | 
					 | 
				
			||||||
        max_size: Maximum size of incoming messages in bytes.
 | 
					 | 
				
			||||||
            :obj:`None` disables the limit.
 | 
					 | 
				
			||||||
        logger: Logger for this client.
 | 
					 | 
				
			||||||
            It defaults to ``logging.getLogger("websockets.client")``.
 | 
					 | 
				
			||||||
            See the :doc:`logging guide <../../topics/logging>` for details.
 | 
					 | 
				
			||||||
        create_connection: Factory for the :class:`ClientConnection` managing
 | 
					 | 
				
			||||||
            the connection. Set it to a wrapper or a subclass to customize
 | 
					 | 
				
			||||||
            connection handling.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Any other keyword arguments are passed to :func:`~socket.create_connection`.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Raises:
 | 
					 | 
				
			||||||
        InvalidURI: If ``uri`` isn't a valid WebSocket URI.
 | 
					 | 
				
			||||||
        OSError: If the TCP connection fails.
 | 
					 | 
				
			||||||
        InvalidHandshake: If the opening handshake fails.
 | 
					 | 
				
			||||||
        TimeoutError: If the opening handshake times out.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    # Process parameters
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    # Backwards compatibility: ssl used to be called ssl_context.
 | 
					 | 
				
			||||||
    if ssl is None and "ssl_context" in kwargs:
 | 
					 | 
				
			||||||
        ssl = kwargs.pop("ssl_context")
 | 
					 | 
				
			||||||
        warnings.warn(  # deprecated in 13.0 - 2024-08-20
 | 
					 | 
				
			||||||
            "ssl_context was renamed to ssl",
 | 
					 | 
				
			||||||
            DeprecationWarning,
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    wsuri = parse_uri(uri)
 | 
					 | 
				
			||||||
    if not wsuri.secure and ssl is not None:
 | 
					 | 
				
			||||||
        raise TypeError("ssl argument is incompatible with a ws:// URI")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    # Private APIs for unix_connect()
 | 
					 | 
				
			||||||
    unix: bool = kwargs.pop("unix", False)
 | 
					 | 
				
			||||||
    path: str | None = kwargs.pop("path", None)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    if unix:
 | 
					 | 
				
			||||||
        if path is None and sock is None:
 | 
					 | 
				
			||||||
            raise TypeError("missing path argument")
 | 
					 | 
				
			||||||
        elif path is not None and sock is not None:
 | 
					 | 
				
			||||||
            raise TypeError("path and sock arguments are incompatible")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    if subprotocols is not None:
 | 
					 | 
				
			||||||
        validate_subprotocols(subprotocols)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    if compression == "deflate":
 | 
					 | 
				
			||||||
        extensions = enable_client_permessage_deflate(extensions)
 | 
					 | 
				
			||||||
    elif compression is not None:
 | 
					 | 
				
			||||||
        raise ValueError(f"unsupported compression: {compression}")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    # Calculate timeouts on the TCP, TLS, and WebSocket handshakes.
 | 
					 | 
				
			||||||
    # The TCP and TLS timeouts must be set on the socket, then removed
 | 
					 | 
				
			||||||
    # to avoid conflicting with the WebSocket timeout in handshake().
 | 
					 | 
				
			||||||
    deadline = Deadline(open_timeout)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    if create_connection is None:
 | 
					 | 
				
			||||||
        create_connection = ClientConnection
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    try:
 | 
					 | 
				
			||||||
        # Connect socket
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if sock is None:
 | 
					 | 
				
			||||||
            if unix:
 | 
					 | 
				
			||||||
                sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
 | 
					 | 
				
			||||||
                sock.settimeout(deadline.timeout())
 | 
					 | 
				
			||||||
                assert path is not None  # mypy cannot figure this out
 | 
					 | 
				
			||||||
                sock.connect(path)
 | 
					 | 
				
			||||||
            else:
 | 
					 | 
				
			||||||
                kwargs.setdefault("timeout", deadline.timeout())
 | 
					 | 
				
			||||||
                sock = socket.create_connection((wsuri.host, wsuri.port), **kwargs)
 | 
					 | 
				
			||||||
            sock.settimeout(None)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # Disable Nagle algorithm
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if not unix:
 | 
					 | 
				
			||||||
            sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, True)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # Initialize TLS wrapper and perform TLS handshake
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if wsuri.secure:
 | 
					 | 
				
			||||||
            if ssl is None:
 | 
					 | 
				
			||||||
                ssl = ssl_module.create_default_context()
 | 
					 | 
				
			||||||
            if server_hostname is None:
 | 
					 | 
				
			||||||
                server_hostname = wsuri.host
 | 
					 | 
				
			||||||
            sock.settimeout(deadline.timeout())
 | 
					 | 
				
			||||||
            sock = ssl.wrap_socket(sock, server_hostname=server_hostname)
 | 
					 | 
				
			||||||
            sock.settimeout(None)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # Initialize WebSocket protocol
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        protocol = ClientProtocol(
 | 
					 | 
				
			||||||
            wsuri,
 | 
					 | 
				
			||||||
            origin=origin,
 | 
					 | 
				
			||||||
            extensions=extensions,
 | 
					 | 
				
			||||||
            subprotocols=subprotocols,
 | 
					 | 
				
			||||||
            max_size=max_size,
 | 
					 | 
				
			||||||
            logger=logger,
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # Initialize WebSocket connection
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        connection = create_connection(
 | 
					 | 
				
			||||||
            sock,
 | 
					 | 
				
			||||||
            protocol,
 | 
					 | 
				
			||||||
            close_timeout=close_timeout,
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
    except Exception:
 | 
					 | 
				
			||||||
        if sock is not None:
 | 
					 | 
				
			||||||
            sock.close()
 | 
					 | 
				
			||||||
        raise
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    try:
 | 
					 | 
				
			||||||
        connection.handshake(
 | 
					 | 
				
			||||||
            additional_headers,
 | 
					 | 
				
			||||||
            user_agent_header,
 | 
					 | 
				
			||||||
            deadline.timeout(),
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
    except Exception:
 | 
					 | 
				
			||||||
        connection.close_socket()
 | 
					 | 
				
			||||||
        connection.recv_events_thread.join()
 | 
					 | 
				
			||||||
        raise
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    return connection
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def unix_connect(
 | 
					 | 
				
			||||||
    path: str | None = None,
 | 
					 | 
				
			||||||
    uri: str | None = None,
 | 
					 | 
				
			||||||
    **kwargs: Any,
 | 
					 | 
				
			||||||
) -> ClientConnection:
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Connect to a WebSocket server listening on a Unix socket.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    This function accepts the same keyword arguments as :func:`connect`.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    It's only available on Unix.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    It's mainly useful for debugging servers listening on Unix sockets.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Args:
 | 
					 | 
				
			||||||
        path: File system path to the Unix socket.
 | 
					 | 
				
			||||||
        uri: URI of the WebSocket server. ``uri`` defaults to
 | 
					 | 
				
			||||||
            ``ws://localhost/`` or, when a ``ssl`` is provided, to
 | 
					 | 
				
			||||||
            ``wss://localhost/``.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    if uri is None:
 | 
					 | 
				
			||||||
        # Backwards compatibility: ssl used to be called ssl_context.
 | 
					 | 
				
			||||||
        if kwargs.get("ssl") is None and kwargs.get("ssl_context") is None:
 | 
					 | 
				
			||||||
            uri = "ws://localhost/"
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            uri = "wss://localhost/"
 | 
					 | 
				
			||||||
    return connect(uri=uri, unix=True, path=path, **kwargs)
 | 
					 | 
				
			||||||
@@ -1,791 +0,0 @@
 | 
				
			|||||||
from __future__ import annotations
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
import contextlib
 | 
					 | 
				
			||||||
import logging
 | 
					 | 
				
			||||||
import random
 | 
					 | 
				
			||||||
import socket
 | 
					 | 
				
			||||||
import struct
 | 
					 | 
				
			||||||
import threading
 | 
					 | 
				
			||||||
import uuid
 | 
					 | 
				
			||||||
from types import TracebackType
 | 
					 | 
				
			||||||
from typing import Any, Iterable, Iterator, Mapping
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
from ..exceptions import (
 | 
					 | 
				
			||||||
    ConcurrencyError,
 | 
					 | 
				
			||||||
    ConnectionClosed,
 | 
					 | 
				
			||||||
    ConnectionClosedOK,
 | 
					 | 
				
			||||||
    ProtocolError,
 | 
					 | 
				
			||||||
)
 | 
					 | 
				
			||||||
from ..frames import DATA_OPCODES, BytesLike, CloseCode, Frame, Opcode
 | 
					 | 
				
			||||||
from ..http11 import Request, Response
 | 
					 | 
				
			||||||
from ..protocol import CLOSED, OPEN, Event, Protocol, State
 | 
					 | 
				
			||||||
from ..typing import Data, LoggerLike, Subprotocol
 | 
					 | 
				
			||||||
from .messages import Assembler
 | 
					 | 
				
			||||||
from .utils import Deadline
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
__all__ = ["Connection"]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class Connection:
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    :mod:`threading` implementation of a WebSocket connection.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    :class:`Connection` provides APIs shared between WebSocket servers and
 | 
					 | 
				
			||||||
    clients.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    You shouldn't use it directly. Instead, use
 | 
					 | 
				
			||||||
    :class:`~websockets.sync.client.ClientConnection` or
 | 
					 | 
				
			||||||
    :class:`~websockets.sync.server.ServerConnection`.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    recv_bufsize = 65536
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __init__(
 | 
					 | 
				
			||||||
        self,
 | 
					 | 
				
			||||||
        socket: socket.socket,
 | 
					 | 
				
			||||||
        protocol: Protocol,
 | 
					 | 
				
			||||||
        *,
 | 
					 | 
				
			||||||
        close_timeout: float | None = 10,
 | 
					 | 
				
			||||||
    ) -> None:
 | 
					 | 
				
			||||||
        self.socket = socket
 | 
					 | 
				
			||||||
        self.protocol = protocol
 | 
					 | 
				
			||||||
        self.close_timeout = close_timeout
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # Inject reference to this instance in the protocol's logger.
 | 
					 | 
				
			||||||
        self.protocol.logger = logging.LoggerAdapter(
 | 
					 | 
				
			||||||
            self.protocol.logger,
 | 
					 | 
				
			||||||
            {"websocket": self},
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # Copy attributes from the protocol for convenience.
 | 
					 | 
				
			||||||
        self.id: uuid.UUID = self.protocol.id
 | 
					 | 
				
			||||||
        """Unique identifier of the connection. Useful in logs."""
 | 
					 | 
				
			||||||
        self.logger: LoggerLike = self.protocol.logger
 | 
					 | 
				
			||||||
        """Logger for this connection."""
 | 
					 | 
				
			||||||
        self.debug = self.protocol.debug
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # HTTP handshake request and response.
 | 
					 | 
				
			||||||
        self.request: Request | None = None
 | 
					 | 
				
			||||||
        """Opening handshake request."""
 | 
					 | 
				
			||||||
        self.response: Response | None = None
 | 
					 | 
				
			||||||
        """Opening handshake response."""
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # Mutex serializing interactions with the protocol.
 | 
					 | 
				
			||||||
        self.protocol_mutex = threading.Lock()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # Assembler turning frames into messages and serializing reads.
 | 
					 | 
				
			||||||
        self.recv_messages = Assembler()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # Whether we are busy sending a fragmented message.
 | 
					 | 
				
			||||||
        self.send_in_progress = False
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # Deadline for the closing handshake.
 | 
					 | 
				
			||||||
        self.close_deadline: Deadline | None = None
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # Mapping of ping IDs to pong waiters, in chronological order.
 | 
					 | 
				
			||||||
        self.ping_waiters: dict[bytes, threading.Event] = {}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # Receiving events from the socket. This thread explicitly is marked as
 | 
					 | 
				
			||||||
        # to support creating a connection in a non-daemon thread then using it
 | 
					 | 
				
			||||||
        # in a daemon thread; this shouldn't block the intpreter from exiting.
 | 
					 | 
				
			||||||
        self.recv_events_thread = threading.Thread(
 | 
					 | 
				
			||||||
            target=self.recv_events,
 | 
					 | 
				
			||||||
            daemon=True,
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        self.recv_events_thread.start()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # Exception raised in recv_events, to be chained to ConnectionClosed
 | 
					 | 
				
			||||||
        # in the user thread in order to show why the TCP connection dropped.
 | 
					 | 
				
			||||||
        self.recv_exc: BaseException | None = None
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    # Public attributes
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    @property
 | 
					 | 
				
			||||||
    def local_address(self) -> Any:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Local address of the connection.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        For IPv4 connections, this is a ``(host, port)`` tuple.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        The format of the address depends on the address family.
 | 
					 | 
				
			||||||
        See :meth:`~socket.socket.getsockname`.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        return self.socket.getsockname()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    @property
 | 
					 | 
				
			||||||
    def remote_address(self) -> Any:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Remote address of the connection.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        For IPv4 connections, this is a ``(host, port)`` tuple.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        The format of the address depends on the address family.
 | 
					 | 
				
			||||||
        See :meth:`~socket.socket.getpeername`.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        return self.socket.getpeername()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    @property
 | 
					 | 
				
			||||||
    def subprotocol(self) -> Subprotocol | None:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Subprotocol negotiated during the opening handshake.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        :obj:`None` if no subprotocol was negotiated.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        return self.protocol.subprotocol
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    # Public methods
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __enter__(self) -> Connection:
 | 
					 | 
				
			||||||
        return self
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __exit__(
 | 
					 | 
				
			||||||
        self,
 | 
					 | 
				
			||||||
        exc_type: type[BaseException] | None,
 | 
					 | 
				
			||||||
        exc_value: BaseException | None,
 | 
					 | 
				
			||||||
        traceback: TracebackType | None,
 | 
					 | 
				
			||||||
    ) -> None:
 | 
					 | 
				
			||||||
        if exc_type is None:
 | 
					 | 
				
			||||||
            self.close()
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            self.close(CloseCode.INTERNAL_ERROR)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __iter__(self) -> Iterator[Data]:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Iterate on incoming messages.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        The iterator calls :meth:`recv` and yields messages in an infinite loop.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        It exits when the connection is closed normally. It raises a
 | 
					 | 
				
			||||||
        :exc:`~websockets.exceptions.ConnectionClosedError` exception after a
 | 
					 | 
				
			||||||
        protocol error or a network failure.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        try:
 | 
					 | 
				
			||||||
            while True:
 | 
					 | 
				
			||||||
                yield self.recv()
 | 
					 | 
				
			||||||
        except ConnectionClosedOK:
 | 
					 | 
				
			||||||
            return
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def recv(self, timeout: float | None = None) -> Data:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Receive the next message.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        When the connection is closed, :meth:`recv` raises
 | 
					 | 
				
			||||||
        :exc:`~websockets.exceptions.ConnectionClosed`. Specifically, it raises
 | 
					 | 
				
			||||||
        :exc:`~websockets.exceptions.ConnectionClosedOK` after a normal closure
 | 
					 | 
				
			||||||
        and :exc:`~websockets.exceptions.ConnectionClosedError` after a protocol
 | 
					 | 
				
			||||||
        error or a network failure. This is how you detect the end of the
 | 
					 | 
				
			||||||
        message stream.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        If ``timeout`` is :obj:`None`, block until a message is received. If
 | 
					 | 
				
			||||||
        ``timeout`` is set and no message is received within ``timeout``
 | 
					 | 
				
			||||||
        seconds, raise :exc:`TimeoutError`. Set ``timeout`` to ``0`` to check if
 | 
					 | 
				
			||||||
        a message was already received.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        If the message is fragmented, wait until all fragments are received,
 | 
					 | 
				
			||||||
        reassemble them, and return the whole message.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Returns:
 | 
					 | 
				
			||||||
            A string (:class:`str`) for a Text_ frame or a bytestring
 | 
					 | 
				
			||||||
            (:class:`bytes`) for a Binary_ frame.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            .. _Text: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6
 | 
					 | 
				
			||||||
            .. _Binary: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Raises:
 | 
					 | 
				
			||||||
            ConnectionClosed: When the connection is closed.
 | 
					 | 
				
			||||||
            ConcurrencyError: If two threads call :meth:`recv` or
 | 
					 | 
				
			||||||
                :meth:`recv_streaming` concurrently.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        try:
 | 
					 | 
				
			||||||
            return self.recv_messages.get(timeout)
 | 
					 | 
				
			||||||
        except EOFError:
 | 
					 | 
				
			||||||
            raise self.protocol.close_exc from self.recv_exc
 | 
					 | 
				
			||||||
        except ConcurrencyError:
 | 
					 | 
				
			||||||
            raise ConcurrencyError(
 | 
					 | 
				
			||||||
                "cannot call recv while another thread "
 | 
					 | 
				
			||||||
                "is already running recv or recv_streaming"
 | 
					 | 
				
			||||||
            ) from None
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def recv_streaming(self) -> Iterator[Data]:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Receive the next message frame by frame.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        If the message is fragmented, yield each fragment as it is received.
 | 
					 | 
				
			||||||
        The iterator must be fully consumed, or else the connection will become
 | 
					 | 
				
			||||||
        unusable.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        :meth:`recv_streaming` raises the same exceptions as :meth:`recv`.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Returns:
 | 
					 | 
				
			||||||
            An iterator of strings (:class:`str`) for a Text_ frame or
 | 
					 | 
				
			||||||
            bytestrings (:class:`bytes`) for a Binary_ frame.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            .. _Text: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6
 | 
					 | 
				
			||||||
            .. _Binary: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Raises:
 | 
					 | 
				
			||||||
            ConnectionClosed: When the connection is closed.
 | 
					 | 
				
			||||||
            ConcurrencyError: If two threads call :meth:`recv` or
 | 
					 | 
				
			||||||
                :meth:`recv_streaming` concurrently.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        try:
 | 
					 | 
				
			||||||
            for frame in self.recv_messages.get_iter():
 | 
					 | 
				
			||||||
                yield frame
 | 
					 | 
				
			||||||
        except EOFError:
 | 
					 | 
				
			||||||
            raise self.protocol.close_exc from self.recv_exc
 | 
					 | 
				
			||||||
        except ConcurrencyError:
 | 
					 | 
				
			||||||
            raise ConcurrencyError(
 | 
					 | 
				
			||||||
                "cannot call recv_streaming while another thread "
 | 
					 | 
				
			||||||
                "is already running recv or recv_streaming"
 | 
					 | 
				
			||||||
            ) from None
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def send(self, message: Data | Iterable[Data]) -> None:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Send a message.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        A string (:class:`str`) is sent as a Text_ frame. A bytestring or
 | 
					 | 
				
			||||||
        bytes-like object (:class:`bytes`, :class:`bytearray`, or
 | 
					 | 
				
			||||||
        :class:`memoryview`) is sent as a Binary_ frame.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        .. _Text: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6
 | 
					 | 
				
			||||||
        .. _Binary: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        :meth:`send` also accepts an iterable of strings, bytestrings, or
 | 
					 | 
				
			||||||
        bytes-like objects to enable fragmentation_. Each item is treated as a
 | 
					 | 
				
			||||||
        message fragment and sent in its own frame. All items must be of the
 | 
					 | 
				
			||||||
        same type, or else :meth:`send` will raise a :exc:`TypeError` and the
 | 
					 | 
				
			||||||
        connection will be closed.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        .. _fragmentation: https://datatracker.ietf.org/doc/html/rfc6455#section-5.4
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        :meth:`send` rejects dict-like objects because this is often an error.
 | 
					 | 
				
			||||||
        (If you really want to send the keys of a dict-like object as fragments,
 | 
					 | 
				
			||||||
        call its :meth:`~dict.keys` method and pass the result to :meth:`send`.)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        When the connection is closed, :meth:`send` raises
 | 
					 | 
				
			||||||
        :exc:`~websockets.exceptions.ConnectionClosed`. Specifically, it
 | 
					 | 
				
			||||||
        raises :exc:`~websockets.exceptions.ConnectionClosedOK` after a normal
 | 
					 | 
				
			||||||
        connection closure and
 | 
					 | 
				
			||||||
        :exc:`~websockets.exceptions.ConnectionClosedError` after a protocol
 | 
					 | 
				
			||||||
        error or a network failure.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Args:
 | 
					 | 
				
			||||||
            message: Message to send.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Raises:
 | 
					 | 
				
			||||||
            ConnectionClosed: When the connection is closed.
 | 
					 | 
				
			||||||
            ConcurrencyError: If the connection is sending a fragmented message.
 | 
					 | 
				
			||||||
            TypeError: If ``message`` doesn't have a supported type.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        # Unfragmented message -- this case must be handled first because
 | 
					 | 
				
			||||||
        # strings and bytes-like objects are iterable.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if isinstance(message, str):
 | 
					 | 
				
			||||||
            with self.send_context():
 | 
					 | 
				
			||||||
                if self.send_in_progress:
 | 
					 | 
				
			||||||
                    raise ConcurrencyError(
 | 
					 | 
				
			||||||
                        "cannot call send while another thread "
 | 
					 | 
				
			||||||
                        "is already running send"
 | 
					 | 
				
			||||||
                    )
 | 
					 | 
				
			||||||
                self.protocol.send_text(message.encode())
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        elif isinstance(message, BytesLike):
 | 
					 | 
				
			||||||
            with self.send_context():
 | 
					 | 
				
			||||||
                if self.send_in_progress:
 | 
					 | 
				
			||||||
                    raise ConcurrencyError(
 | 
					 | 
				
			||||||
                        "cannot call send while another thread "
 | 
					 | 
				
			||||||
                        "is already running send"
 | 
					 | 
				
			||||||
                    )
 | 
					 | 
				
			||||||
                self.protocol.send_binary(message)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # Catch a common mistake -- passing a dict to send().
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        elif isinstance(message, Mapping):
 | 
					 | 
				
			||||||
            raise TypeError("data is a dict-like object")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # Fragmented message -- regular iterator.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        elif isinstance(message, Iterable):
 | 
					 | 
				
			||||||
            chunks = iter(message)
 | 
					 | 
				
			||||||
            try:
 | 
					 | 
				
			||||||
                chunk = next(chunks)
 | 
					 | 
				
			||||||
            except StopIteration:
 | 
					 | 
				
			||||||
                return
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            try:
 | 
					 | 
				
			||||||
                # First fragment.
 | 
					 | 
				
			||||||
                if isinstance(chunk, str):
 | 
					 | 
				
			||||||
                    text = True
 | 
					 | 
				
			||||||
                    with self.send_context():
 | 
					 | 
				
			||||||
                        if self.send_in_progress:
 | 
					 | 
				
			||||||
                            raise ConcurrencyError(
 | 
					 | 
				
			||||||
                                "cannot call send while another thread "
 | 
					 | 
				
			||||||
                                "is already running send"
 | 
					 | 
				
			||||||
                            )
 | 
					 | 
				
			||||||
                        self.send_in_progress = True
 | 
					 | 
				
			||||||
                        self.protocol.send_text(
 | 
					 | 
				
			||||||
                            chunk.encode(),
 | 
					 | 
				
			||||||
                            fin=False,
 | 
					 | 
				
			||||||
                        )
 | 
					 | 
				
			||||||
                elif isinstance(chunk, BytesLike):
 | 
					 | 
				
			||||||
                    text = False
 | 
					 | 
				
			||||||
                    with self.send_context():
 | 
					 | 
				
			||||||
                        if self.send_in_progress:
 | 
					 | 
				
			||||||
                            raise ConcurrencyError(
 | 
					 | 
				
			||||||
                                "cannot call send while another thread "
 | 
					 | 
				
			||||||
                                "is already running send"
 | 
					 | 
				
			||||||
                            )
 | 
					 | 
				
			||||||
                        self.send_in_progress = True
 | 
					 | 
				
			||||||
                        self.protocol.send_binary(
 | 
					 | 
				
			||||||
                            chunk,
 | 
					 | 
				
			||||||
                            fin=False,
 | 
					 | 
				
			||||||
                        )
 | 
					 | 
				
			||||||
                else:
 | 
					 | 
				
			||||||
                    raise TypeError("data iterable must contain bytes or str")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                # Other fragments
 | 
					 | 
				
			||||||
                for chunk in chunks:
 | 
					 | 
				
			||||||
                    if isinstance(chunk, str) and text:
 | 
					 | 
				
			||||||
                        with self.send_context():
 | 
					 | 
				
			||||||
                            assert self.send_in_progress
 | 
					 | 
				
			||||||
                            self.protocol.send_continuation(
 | 
					 | 
				
			||||||
                                chunk.encode(),
 | 
					 | 
				
			||||||
                                fin=False,
 | 
					 | 
				
			||||||
                            )
 | 
					 | 
				
			||||||
                    elif isinstance(chunk, BytesLike) and not text:
 | 
					 | 
				
			||||||
                        with self.send_context():
 | 
					 | 
				
			||||||
                            assert self.send_in_progress
 | 
					 | 
				
			||||||
                            self.protocol.send_continuation(
 | 
					 | 
				
			||||||
                                chunk,
 | 
					 | 
				
			||||||
                                fin=False,
 | 
					 | 
				
			||||||
                            )
 | 
					 | 
				
			||||||
                    else:
 | 
					 | 
				
			||||||
                        raise TypeError("data iterable must contain uniform types")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                # Final fragment.
 | 
					 | 
				
			||||||
                with self.send_context():
 | 
					 | 
				
			||||||
                    self.protocol.send_continuation(b"", fin=True)
 | 
					 | 
				
			||||||
                    self.send_in_progress = False
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            except ConcurrencyError:
 | 
					 | 
				
			||||||
                # We didn't start sending a fragmented message.
 | 
					 | 
				
			||||||
                # The connection is still usable.
 | 
					 | 
				
			||||||
                raise
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            except Exception:
 | 
					 | 
				
			||||||
                # We're half-way through a fragmented message and we can't
 | 
					 | 
				
			||||||
                # complete it. This makes the connection unusable.
 | 
					 | 
				
			||||||
                with self.send_context():
 | 
					 | 
				
			||||||
                    self.protocol.fail(
 | 
					 | 
				
			||||||
                        CloseCode.INTERNAL_ERROR,
 | 
					 | 
				
			||||||
                        "error in fragmented message",
 | 
					 | 
				
			||||||
                    )
 | 
					 | 
				
			||||||
                raise
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            raise TypeError("data must be str, bytes, or iterable")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def close(self, code: int = CloseCode.NORMAL_CLOSURE, reason: str = "") -> None:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Perform the closing handshake.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        :meth:`close` waits for the other end to complete the handshake, for the
 | 
					 | 
				
			||||||
        TCP connection to terminate, and for all incoming messages to be read
 | 
					 | 
				
			||||||
        with :meth:`recv`.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        :meth:`close` is idempotent: it doesn't do anything once the
 | 
					 | 
				
			||||||
        connection is closed.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Args:
 | 
					 | 
				
			||||||
            code: WebSocket close code.
 | 
					 | 
				
			||||||
            reason: WebSocket close reason.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        try:
 | 
					 | 
				
			||||||
            # The context manager takes care of waiting for the TCP connection
 | 
					 | 
				
			||||||
            # to terminate after calling a method that sends a close frame.
 | 
					 | 
				
			||||||
            with self.send_context():
 | 
					 | 
				
			||||||
                if self.send_in_progress:
 | 
					 | 
				
			||||||
                    self.protocol.fail(
 | 
					 | 
				
			||||||
                        CloseCode.INTERNAL_ERROR,
 | 
					 | 
				
			||||||
                        "close during fragmented message",
 | 
					 | 
				
			||||||
                    )
 | 
					 | 
				
			||||||
                else:
 | 
					 | 
				
			||||||
                    self.protocol.send_close(code, reason)
 | 
					 | 
				
			||||||
        except ConnectionClosed:
 | 
					 | 
				
			||||||
            # Ignore ConnectionClosed exceptions raised from send_context().
 | 
					 | 
				
			||||||
            # They mean that the connection is closed, which was the goal.
 | 
					 | 
				
			||||||
            pass
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def ping(self, data: Data | None = None) -> threading.Event:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Send a Ping_.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        .. _Ping: https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.2
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        A ping may serve as a keepalive or as a check that the remote endpoint
 | 
					 | 
				
			||||||
        received all messages up to this point
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Args:
 | 
					 | 
				
			||||||
            data: Payload of the ping. A :class:`str` will be encoded to UTF-8.
 | 
					 | 
				
			||||||
                If ``data`` is :obj:`None`, the payload is four random bytes.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Returns:
 | 
					 | 
				
			||||||
            An event that will be set when the corresponding pong is received.
 | 
					 | 
				
			||||||
            You can ignore it if you don't intend to wait.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            ::
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                pong_event = ws.ping()
 | 
					 | 
				
			||||||
                pong_event.wait()  # only if you want to wait for the pong
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Raises:
 | 
					 | 
				
			||||||
            ConnectionClosed: When the connection is closed.
 | 
					 | 
				
			||||||
            ConcurrencyError: If another ping was sent with the same data and
 | 
					 | 
				
			||||||
                the corresponding pong wasn't received yet.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        if isinstance(data, BytesLike):
 | 
					 | 
				
			||||||
            data = bytes(data)
 | 
					 | 
				
			||||||
        elif isinstance(data, str):
 | 
					 | 
				
			||||||
            data = data.encode()
 | 
					 | 
				
			||||||
        elif data is not None:
 | 
					 | 
				
			||||||
            raise TypeError("data must be str or bytes-like")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        with self.send_context():
 | 
					 | 
				
			||||||
            # Protect against duplicates if a payload is explicitly set.
 | 
					 | 
				
			||||||
            if data in self.ping_waiters:
 | 
					 | 
				
			||||||
                raise ConcurrencyError("already waiting for a pong with the same data")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            # Generate a unique random payload otherwise.
 | 
					 | 
				
			||||||
            while data is None or data in self.ping_waiters:
 | 
					 | 
				
			||||||
                data = struct.pack("!I", random.getrandbits(32))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            pong_waiter = threading.Event()
 | 
					 | 
				
			||||||
            self.ping_waiters[data] = pong_waiter
 | 
					 | 
				
			||||||
            self.protocol.send_ping(data)
 | 
					 | 
				
			||||||
            return pong_waiter
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def pong(self, data: Data = b"") -> None:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Send a Pong_.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        .. _Pong: https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.3
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        An unsolicited pong may serve as a unidirectional heartbeat.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Args:
 | 
					 | 
				
			||||||
            data: Payload of the pong. A :class:`str` will be encoded to UTF-8.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Raises:
 | 
					 | 
				
			||||||
            ConnectionClosed: When the connection is closed.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        if isinstance(data, BytesLike):
 | 
					 | 
				
			||||||
            data = bytes(data)
 | 
					 | 
				
			||||||
        elif isinstance(data, str):
 | 
					 | 
				
			||||||
            data = data.encode()
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            raise TypeError("data must be str or bytes-like")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        with self.send_context():
 | 
					 | 
				
			||||||
            self.protocol.send_pong(data)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    # Private methods
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def process_event(self, event: Event) -> None:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Process one incoming event.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        This method is overridden in subclasses to handle the handshake.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        assert isinstance(event, Frame)
 | 
					 | 
				
			||||||
        if event.opcode in DATA_OPCODES:
 | 
					 | 
				
			||||||
            self.recv_messages.put(event)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if event.opcode is Opcode.PONG:
 | 
					 | 
				
			||||||
            self.acknowledge_pings(bytes(event.data))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def acknowledge_pings(self, data: bytes) -> None:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Acknowledge pings when receiving a pong.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        with self.protocol_mutex:
 | 
					 | 
				
			||||||
            # Ignore unsolicited pong.
 | 
					 | 
				
			||||||
            if data not in self.ping_waiters:
 | 
					 | 
				
			||||||
                return
 | 
					 | 
				
			||||||
            # Sending a pong for only the most recent ping is legal.
 | 
					 | 
				
			||||||
            # Acknowledge all previous pings too in that case.
 | 
					 | 
				
			||||||
            ping_id = None
 | 
					 | 
				
			||||||
            ping_ids = []
 | 
					 | 
				
			||||||
            for ping_id, ping in self.ping_waiters.items():
 | 
					 | 
				
			||||||
                ping_ids.append(ping_id)
 | 
					 | 
				
			||||||
                ping.set()
 | 
					 | 
				
			||||||
                if ping_id == data:
 | 
					 | 
				
			||||||
                    break
 | 
					 | 
				
			||||||
            else:
 | 
					 | 
				
			||||||
                raise AssertionError("solicited pong not found in pings")
 | 
					 | 
				
			||||||
            # Remove acknowledged pings from self.ping_waiters.
 | 
					 | 
				
			||||||
            for ping_id in ping_ids:
 | 
					 | 
				
			||||||
                del self.ping_waiters[ping_id]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def recv_events(self) -> None:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Read incoming data from the socket and process events.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Run this method in a thread as long as the connection is alive.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        ``recv_events()`` exits immediately when the ``self.socket`` is closed.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        try:
 | 
					 | 
				
			||||||
            while True:
 | 
					 | 
				
			||||||
                try:
 | 
					 | 
				
			||||||
                    if self.close_deadline is not None:
 | 
					 | 
				
			||||||
                        self.socket.settimeout(self.close_deadline.timeout())
 | 
					 | 
				
			||||||
                    data = self.socket.recv(self.recv_bufsize)
 | 
					 | 
				
			||||||
                except Exception as exc:
 | 
					 | 
				
			||||||
                    if self.debug:
 | 
					 | 
				
			||||||
                        self.logger.debug("error while receiving data", exc_info=True)
 | 
					 | 
				
			||||||
                    # When the closing handshake is initiated by our side,
 | 
					 | 
				
			||||||
                    # recv() may block until send_context() closes the socket.
 | 
					 | 
				
			||||||
                    # In that case, send_context() already set recv_exc.
 | 
					 | 
				
			||||||
                    # Calling set_recv_exc() avoids overwriting it.
 | 
					 | 
				
			||||||
                    with self.protocol_mutex:
 | 
					 | 
				
			||||||
                        self.set_recv_exc(exc)
 | 
					 | 
				
			||||||
                    break
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                if data == b"":
 | 
					 | 
				
			||||||
                    break
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                # Acquire the connection lock.
 | 
					 | 
				
			||||||
                with self.protocol_mutex:
 | 
					 | 
				
			||||||
                    # Feed incoming data to the protocol.
 | 
					 | 
				
			||||||
                    self.protocol.receive_data(data)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                    # This isn't expected to raise an exception.
 | 
					 | 
				
			||||||
                    events = self.protocol.events_received()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                    # Write outgoing data to the socket.
 | 
					 | 
				
			||||||
                    try:
 | 
					 | 
				
			||||||
                        self.send_data()
 | 
					 | 
				
			||||||
                    except Exception as exc:
 | 
					 | 
				
			||||||
                        if self.debug:
 | 
					 | 
				
			||||||
                            self.logger.debug("error while sending data", exc_info=True)
 | 
					 | 
				
			||||||
                        # Similarly to the above, avoid overriding an exception
 | 
					 | 
				
			||||||
                        # set by send_context(), in case of a race condition
 | 
					 | 
				
			||||||
                        # i.e. send_context() closes the socket after recv()
 | 
					 | 
				
			||||||
                        # returns above but before send_data() calls send().
 | 
					 | 
				
			||||||
                        self.set_recv_exc(exc)
 | 
					 | 
				
			||||||
                        break
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                    if self.protocol.close_expected():
 | 
					 | 
				
			||||||
                        # If the connection is expected to close soon, set the
 | 
					 | 
				
			||||||
                        # close deadline based on the close timeout.
 | 
					 | 
				
			||||||
                        if self.close_deadline is None:
 | 
					 | 
				
			||||||
                            self.close_deadline = Deadline(self.close_timeout)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                # Unlock conn_mutex before processing events. Else, the
 | 
					 | 
				
			||||||
                # application can't send messages in response to events.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                # If self.send_data raised an exception, then events are lost.
 | 
					 | 
				
			||||||
                # Given that automatic responses write small amounts of data,
 | 
					 | 
				
			||||||
                # this should be uncommon, so we don't handle the edge case.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                try:
 | 
					 | 
				
			||||||
                    for event in events:
 | 
					 | 
				
			||||||
                        # This may raise EOFError if the closing handshake
 | 
					 | 
				
			||||||
                        # times out while a message is waiting to be read.
 | 
					 | 
				
			||||||
                        self.process_event(event)
 | 
					 | 
				
			||||||
                except EOFError:
 | 
					 | 
				
			||||||
                    break
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            # Breaking out of the while True: ... loop means that we believe
 | 
					 | 
				
			||||||
            # that the socket doesn't work anymore.
 | 
					 | 
				
			||||||
            with self.protocol_mutex:
 | 
					 | 
				
			||||||
                # Feed the end of the data stream to the protocol.
 | 
					 | 
				
			||||||
                self.protocol.receive_eof()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                # This isn't expected to generate events.
 | 
					 | 
				
			||||||
                assert not self.protocol.events_received()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                # There is no error handling because send_data() can only write
 | 
					 | 
				
			||||||
                # the end of the data stream here and it handles errors itself.
 | 
					 | 
				
			||||||
                self.send_data()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        except Exception as exc:
 | 
					 | 
				
			||||||
            # This branch should never run. It's a safety net in case of bugs.
 | 
					 | 
				
			||||||
            self.logger.error("unexpected internal error", exc_info=True)
 | 
					 | 
				
			||||||
            with self.protocol_mutex:
 | 
					 | 
				
			||||||
                self.set_recv_exc(exc)
 | 
					 | 
				
			||||||
            # We don't know where we crashed. Force protocol state to CLOSED.
 | 
					 | 
				
			||||||
            self.protocol.state = CLOSED
 | 
					 | 
				
			||||||
        finally:
 | 
					 | 
				
			||||||
            # This isn't expected to raise an exception.
 | 
					 | 
				
			||||||
            self.close_socket()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    @contextlib.contextmanager
 | 
					 | 
				
			||||||
    def send_context(
 | 
					 | 
				
			||||||
        self,
 | 
					 | 
				
			||||||
        *,
 | 
					 | 
				
			||||||
        expected_state: State = OPEN,  # CONNECTING during the opening handshake
 | 
					 | 
				
			||||||
    ) -> Iterator[None]:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Create a context for writing to the connection from user code.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        On entry, :meth:`send_context` acquires the connection lock and checks
 | 
					 | 
				
			||||||
        that the connection is open; on exit, it writes outgoing data to the
 | 
					 | 
				
			||||||
        socket::
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            with self.send_context():
 | 
					 | 
				
			||||||
                self.protocol.send_text(message.encode())
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        When the connection isn't open on entry, when the connection is expected
 | 
					 | 
				
			||||||
        to close on exit, or when an unexpected error happens, terminating the
 | 
					 | 
				
			||||||
        connection, :meth:`send_context` waits until the connection is closed
 | 
					 | 
				
			||||||
        then raises :exc:`~websockets.exceptions.ConnectionClosed`.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        # Should we wait until the connection is closed?
 | 
					 | 
				
			||||||
        wait_for_close = False
 | 
					 | 
				
			||||||
        # Should we close the socket and raise ConnectionClosed?
 | 
					 | 
				
			||||||
        raise_close_exc = False
 | 
					 | 
				
			||||||
        # What exception should we chain ConnectionClosed to?
 | 
					 | 
				
			||||||
        original_exc: BaseException | None = None
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # Acquire the protocol lock.
 | 
					 | 
				
			||||||
        with self.protocol_mutex:
 | 
					 | 
				
			||||||
            if self.protocol.state is expected_state:
 | 
					 | 
				
			||||||
                # Let the caller interact with the protocol.
 | 
					 | 
				
			||||||
                try:
 | 
					 | 
				
			||||||
                    yield
 | 
					 | 
				
			||||||
                except (ProtocolError, ConcurrencyError):
 | 
					 | 
				
			||||||
                    # The protocol state wasn't changed. Exit immediately.
 | 
					 | 
				
			||||||
                    raise
 | 
					 | 
				
			||||||
                except Exception as exc:
 | 
					 | 
				
			||||||
                    self.logger.error("unexpected internal error", exc_info=True)
 | 
					 | 
				
			||||||
                    # This branch should never run. It's a safety net in case of
 | 
					 | 
				
			||||||
                    # bugs. Since we don't know what happened, we will close the
 | 
					 | 
				
			||||||
                    # connection and raise the exception to the caller.
 | 
					 | 
				
			||||||
                    wait_for_close = False
 | 
					 | 
				
			||||||
                    raise_close_exc = True
 | 
					 | 
				
			||||||
                    original_exc = exc
 | 
					 | 
				
			||||||
                else:
 | 
					 | 
				
			||||||
                    # Check if the connection is expected to close soon.
 | 
					 | 
				
			||||||
                    if self.protocol.close_expected():
 | 
					 | 
				
			||||||
                        wait_for_close = True
 | 
					 | 
				
			||||||
                        # If the connection is expected to close soon, set the
 | 
					 | 
				
			||||||
                        # close deadline based on the close timeout.
 | 
					 | 
				
			||||||
                        # Since we tested earlier that protocol.state was OPEN
 | 
					 | 
				
			||||||
                        # (or CONNECTING) and we didn't release protocol_mutex,
 | 
					 | 
				
			||||||
                        # it is certain that self.close_deadline is still None.
 | 
					 | 
				
			||||||
                        assert self.close_deadline is None
 | 
					 | 
				
			||||||
                        self.close_deadline = Deadline(self.close_timeout)
 | 
					 | 
				
			||||||
                    # Write outgoing data to the socket.
 | 
					 | 
				
			||||||
                    try:
 | 
					 | 
				
			||||||
                        self.send_data()
 | 
					 | 
				
			||||||
                    except Exception as exc:
 | 
					 | 
				
			||||||
                        if self.debug:
 | 
					 | 
				
			||||||
                            self.logger.debug("error while sending data", exc_info=True)
 | 
					 | 
				
			||||||
                        # While the only expected exception here is OSError,
 | 
					 | 
				
			||||||
                        # other exceptions would be treated identically.
 | 
					 | 
				
			||||||
                        wait_for_close = False
 | 
					 | 
				
			||||||
                        raise_close_exc = True
 | 
					 | 
				
			||||||
                        original_exc = exc
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            else:  # self.protocol.state is not expected_state
 | 
					 | 
				
			||||||
                # Minor layering violation: we assume that the connection
 | 
					 | 
				
			||||||
                # will be closing soon if it isn't in the expected state.
 | 
					 | 
				
			||||||
                wait_for_close = True
 | 
					 | 
				
			||||||
                raise_close_exc = True
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # To avoid a deadlock, release the connection lock by exiting the
 | 
					 | 
				
			||||||
        # context manager before waiting for recv_events() to terminate.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # If the connection is expected to close soon and the close timeout
 | 
					 | 
				
			||||||
        # elapses, close the socket to terminate the connection.
 | 
					 | 
				
			||||||
        if wait_for_close:
 | 
					 | 
				
			||||||
            if self.close_deadline is None:
 | 
					 | 
				
			||||||
                timeout = self.close_timeout
 | 
					 | 
				
			||||||
            else:
 | 
					 | 
				
			||||||
                # Thread.join() returns immediately if timeout is negative.
 | 
					 | 
				
			||||||
                timeout = self.close_deadline.timeout(raise_if_elapsed=False)
 | 
					 | 
				
			||||||
            self.recv_events_thread.join(timeout)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            if self.recv_events_thread.is_alive():
 | 
					 | 
				
			||||||
                # There's no risk to overwrite another error because
 | 
					 | 
				
			||||||
                # original_exc is never set when wait_for_close is True.
 | 
					 | 
				
			||||||
                assert original_exc is None
 | 
					 | 
				
			||||||
                original_exc = TimeoutError("timed out while closing connection")
 | 
					 | 
				
			||||||
                # Set recv_exc before closing the socket in order to get
 | 
					 | 
				
			||||||
                # proper exception reporting.
 | 
					 | 
				
			||||||
                raise_close_exc = True
 | 
					 | 
				
			||||||
                with self.protocol_mutex:
 | 
					 | 
				
			||||||
                    self.set_recv_exc(original_exc)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # If an error occurred, close the socket to terminate the connection and
 | 
					 | 
				
			||||||
        # raise an exception.
 | 
					 | 
				
			||||||
        if raise_close_exc:
 | 
					 | 
				
			||||||
            self.close_socket()
 | 
					 | 
				
			||||||
            self.recv_events_thread.join()
 | 
					 | 
				
			||||||
            raise self.protocol.close_exc from original_exc
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def send_data(self) -> None:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Send outgoing data.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        This method requires holding protocol_mutex.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Raises:
 | 
					 | 
				
			||||||
            OSError: When a socket operations fails.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        assert self.protocol_mutex.locked()
 | 
					 | 
				
			||||||
        for data in self.protocol.data_to_send():
 | 
					 | 
				
			||||||
            if data:
 | 
					 | 
				
			||||||
                if self.close_deadline is not None:
 | 
					 | 
				
			||||||
                    self.socket.settimeout(self.close_deadline.timeout())
 | 
					 | 
				
			||||||
                self.socket.sendall(data)
 | 
					 | 
				
			||||||
            else:
 | 
					 | 
				
			||||||
                try:
 | 
					 | 
				
			||||||
                    self.socket.shutdown(socket.SHUT_WR)
 | 
					 | 
				
			||||||
                except OSError:  # socket already closed
 | 
					 | 
				
			||||||
                    pass
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def set_recv_exc(self, exc: BaseException | None) -> None:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Set recv_exc, if not set yet.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        This method requires holding protocol_mutex.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        assert self.protocol_mutex.locked()
 | 
					 | 
				
			||||||
        if self.recv_exc is None:  # pragma: no branch
 | 
					 | 
				
			||||||
            self.recv_exc = exc
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def close_socket(self) -> None:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Shutdown and close socket. Close message assembler.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Calling close_socket() guarantees that recv_events() terminates. Indeed,
 | 
					 | 
				
			||||||
        recv_events() may block only on socket.recv() or on recv_messages.put().
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        # shutdown() is required to interrupt recv() on Linux.
 | 
					 | 
				
			||||||
        try:
 | 
					 | 
				
			||||||
            self.socket.shutdown(socket.SHUT_RDWR)
 | 
					 | 
				
			||||||
        except OSError:
 | 
					 | 
				
			||||||
            pass  # socket is already closed
 | 
					 | 
				
			||||||
        self.socket.close()
 | 
					 | 
				
			||||||
        self.recv_messages.close()
 | 
					 | 
				
			||||||
@@ -1,283 +0,0 @@
 | 
				
			|||||||
from __future__ import annotations
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
import codecs
 | 
					 | 
				
			||||||
import queue
 | 
					 | 
				
			||||||
import threading
 | 
					 | 
				
			||||||
from typing import Iterator, cast
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
from ..exceptions import ConcurrencyError
 | 
					 | 
				
			||||||
from ..frames import OP_BINARY, OP_CONT, OP_TEXT, Frame
 | 
					 | 
				
			||||||
from ..typing import Data
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
__all__ = ["Assembler"]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
UTF8Decoder = codecs.getincrementaldecoder("utf-8")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class Assembler:
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Assemble messages from frames.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __init__(self) -> None:
 | 
					 | 
				
			||||||
        # Serialize reads and writes -- except for reads via synchronization
 | 
					 | 
				
			||||||
        # primitives provided by the threading and queue modules.
 | 
					 | 
				
			||||||
        self.mutex = threading.Lock()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # We create a latch with two events to synchronize the production of
 | 
					 | 
				
			||||||
        # frames and the consumption of messages (or frames) without a buffer.
 | 
					 | 
				
			||||||
        # This design requires a switch between the library thread and the user
 | 
					 | 
				
			||||||
        # thread for each message; that shouldn't be a performance bottleneck.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # put() sets this event to tell get() that a message can be fetched.
 | 
					 | 
				
			||||||
        self.message_complete = threading.Event()
 | 
					 | 
				
			||||||
        # get() sets this event to let put() that the message was fetched.
 | 
					 | 
				
			||||||
        self.message_fetched = threading.Event()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # This flag prevents concurrent calls to get() by user code.
 | 
					 | 
				
			||||||
        self.get_in_progress = False
 | 
					 | 
				
			||||||
        # This flag prevents concurrent calls to put() by library code.
 | 
					 | 
				
			||||||
        self.put_in_progress = False
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # Decoder for text frames, None for binary frames.
 | 
					 | 
				
			||||||
        self.decoder: codecs.IncrementalDecoder | None = None
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # Buffer of frames belonging to the same message.
 | 
					 | 
				
			||||||
        self.chunks: list[Data] = []
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # When switching from "buffering" to "streaming", we use a thread-safe
 | 
					 | 
				
			||||||
        # queue for transferring frames from the writing thread (library code)
 | 
					 | 
				
			||||||
        # to the reading thread (user code). We're buffering when chunks_queue
 | 
					 | 
				
			||||||
        # is None and streaming when it's a SimpleQueue. None is a sentinel
 | 
					 | 
				
			||||||
        # value marking the end of the message, superseding message_complete.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # Stream data from frames belonging to the same message.
 | 
					 | 
				
			||||||
        self.chunks_queue: queue.SimpleQueue[Data | None] | None = None
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # This flag marks the end of the connection.
 | 
					 | 
				
			||||||
        self.closed = False
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def get(self, timeout: float | None = None) -> Data:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Read the next message.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        :meth:`get` returns a single :class:`str` or :class:`bytes`.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        If the message is fragmented, :meth:`get` waits until the last frame is
 | 
					 | 
				
			||||||
        received, then it reassembles the message and returns it. To receive
 | 
					 | 
				
			||||||
        messages frame by frame, use :meth:`get_iter` instead.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Args:
 | 
					 | 
				
			||||||
            timeout: If a timeout is provided and elapses before a complete
 | 
					 | 
				
			||||||
                message is received, :meth:`get` raises :exc:`TimeoutError`.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Raises:
 | 
					 | 
				
			||||||
            EOFError: If the stream of frames has ended.
 | 
					 | 
				
			||||||
            ConcurrencyError: If two threads run :meth:`get` or :meth:`get_iter`
 | 
					 | 
				
			||||||
                concurrently.
 | 
					 | 
				
			||||||
            TimeoutError: If a timeout is provided and elapses before a
 | 
					 | 
				
			||||||
                complete message is received.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        with self.mutex:
 | 
					 | 
				
			||||||
            if self.closed:
 | 
					 | 
				
			||||||
                raise EOFError("stream of frames ended")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            if self.get_in_progress:
 | 
					 | 
				
			||||||
                raise ConcurrencyError("get() or get_iter() is already running")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            self.get_in_progress = True
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # If the message_complete event isn't set yet, release the lock to
 | 
					 | 
				
			||||||
        # allow put() to run and eventually set it.
 | 
					 | 
				
			||||||
        # Locking with get_in_progress ensures only one thread can get here.
 | 
					 | 
				
			||||||
        completed = self.message_complete.wait(timeout)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        with self.mutex:
 | 
					 | 
				
			||||||
            self.get_in_progress = False
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            # Waiting for a complete message timed out.
 | 
					 | 
				
			||||||
            if not completed:
 | 
					 | 
				
			||||||
                raise TimeoutError(f"timed out in {timeout:.1f}s")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            # get() was unblocked by close() rather than put().
 | 
					 | 
				
			||||||
            if self.closed:
 | 
					 | 
				
			||||||
                raise EOFError("stream of frames ended")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            assert self.message_complete.is_set()
 | 
					 | 
				
			||||||
            self.message_complete.clear()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            joiner: Data = b"" if self.decoder is None else ""
 | 
					 | 
				
			||||||
            # mypy cannot figure out that chunks have the proper type.
 | 
					 | 
				
			||||||
            message: Data = joiner.join(self.chunks)  # type: ignore
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            self.chunks = []
 | 
					 | 
				
			||||||
            assert self.chunks_queue is None
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            assert not self.message_fetched.is_set()
 | 
					 | 
				
			||||||
            self.message_fetched.set()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            return message
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def get_iter(self) -> Iterator[Data]:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Stream the next message.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Iterating the return value of :meth:`get_iter` yields a :class:`str` or
 | 
					 | 
				
			||||||
        :class:`bytes` for each frame in the message.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        The iterator must be fully consumed before calling :meth:`get_iter` or
 | 
					 | 
				
			||||||
        :meth:`get` again. Else, :exc:`ConcurrencyError` is raised.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        This method only makes sense for fragmented messages. If messages aren't
 | 
					 | 
				
			||||||
        fragmented, use :meth:`get` instead.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Raises:
 | 
					 | 
				
			||||||
            EOFError: If the stream of frames has ended.
 | 
					 | 
				
			||||||
            ConcurrencyError: If two threads run :meth:`get` or :meth:`get_iter`
 | 
					 | 
				
			||||||
                concurrently.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        with self.mutex:
 | 
					 | 
				
			||||||
            if self.closed:
 | 
					 | 
				
			||||||
                raise EOFError("stream of frames ended")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            if self.get_in_progress:
 | 
					 | 
				
			||||||
                raise ConcurrencyError("get() or get_iter() is already running")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            chunks = self.chunks
 | 
					 | 
				
			||||||
            self.chunks = []
 | 
					 | 
				
			||||||
            self.chunks_queue = cast(
 | 
					 | 
				
			||||||
                # Remove quotes around type when dropping Python < 3.9.
 | 
					 | 
				
			||||||
                "queue.SimpleQueue[Data | None]",
 | 
					 | 
				
			||||||
                queue.SimpleQueue(),
 | 
					 | 
				
			||||||
            )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            # Sending None in chunk_queue supersedes setting message_complete
 | 
					 | 
				
			||||||
            # when switching to "streaming". If message is already complete
 | 
					 | 
				
			||||||
            # when the switch happens, put() didn't send None, so we have to.
 | 
					 | 
				
			||||||
            if self.message_complete.is_set():
 | 
					 | 
				
			||||||
                self.chunks_queue.put(None)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            self.get_in_progress = True
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # Locking with get_in_progress ensures only one thread can get here.
 | 
					 | 
				
			||||||
        chunk: Data | None
 | 
					 | 
				
			||||||
        for chunk in chunks:
 | 
					 | 
				
			||||||
            yield chunk
 | 
					 | 
				
			||||||
        while (chunk := self.chunks_queue.get()) is not None:
 | 
					 | 
				
			||||||
            yield chunk
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        with self.mutex:
 | 
					 | 
				
			||||||
            self.get_in_progress = False
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            # get_iter() was unblocked by close() rather than put().
 | 
					 | 
				
			||||||
            if self.closed:
 | 
					 | 
				
			||||||
                raise EOFError("stream of frames ended")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            assert self.message_complete.is_set()
 | 
					 | 
				
			||||||
            self.message_complete.clear()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            assert self.chunks == []
 | 
					 | 
				
			||||||
            self.chunks_queue = None
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            assert not self.message_fetched.is_set()
 | 
					 | 
				
			||||||
            self.message_fetched.set()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def put(self, frame: Frame) -> None:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Add ``frame`` to the next message.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        When ``frame`` is the final frame in a message, :meth:`put` waits until
 | 
					 | 
				
			||||||
        the message is fetched, which can be achieved by calling :meth:`get` or
 | 
					 | 
				
			||||||
        by fully consuming the return value of :meth:`get_iter`.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        :meth:`put` assumes that the stream of frames respects the protocol. If
 | 
					 | 
				
			||||||
        it doesn't, the behavior is undefined.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Raises:
 | 
					 | 
				
			||||||
            EOFError: If the stream of frames has ended.
 | 
					 | 
				
			||||||
            ConcurrencyError: If two threads run :meth:`put` concurrently.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        with self.mutex:
 | 
					 | 
				
			||||||
            if self.closed:
 | 
					 | 
				
			||||||
                raise EOFError("stream of frames ended")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            if self.put_in_progress:
 | 
					 | 
				
			||||||
                raise ConcurrencyError("put is already running")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            if frame.opcode is OP_TEXT:
 | 
					 | 
				
			||||||
                self.decoder = UTF8Decoder(errors="strict")
 | 
					 | 
				
			||||||
            elif frame.opcode is OP_BINARY:
 | 
					 | 
				
			||||||
                self.decoder = None
 | 
					 | 
				
			||||||
            else:
 | 
					 | 
				
			||||||
                assert frame.opcode is OP_CONT
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            data: Data
 | 
					 | 
				
			||||||
            if self.decoder is not None:
 | 
					 | 
				
			||||||
                data = self.decoder.decode(frame.data, frame.fin)
 | 
					 | 
				
			||||||
            else:
 | 
					 | 
				
			||||||
                data = frame.data
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            if self.chunks_queue is None:
 | 
					 | 
				
			||||||
                self.chunks.append(data)
 | 
					 | 
				
			||||||
            else:
 | 
					 | 
				
			||||||
                self.chunks_queue.put(data)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            if not frame.fin:
 | 
					 | 
				
			||||||
                return
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            # Message is complete. Wait until it's fetched to return.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            assert not self.message_complete.is_set()
 | 
					 | 
				
			||||||
            self.message_complete.set()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            if self.chunks_queue is not None:
 | 
					 | 
				
			||||||
                self.chunks_queue.put(None)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            assert not self.message_fetched.is_set()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            self.put_in_progress = True
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # Release the lock to allow get() to run and eventually set the event.
 | 
					 | 
				
			||||||
        # Locking with put_in_progress ensures only one coroutine can get here.
 | 
					 | 
				
			||||||
        self.message_fetched.wait()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        with self.mutex:
 | 
					 | 
				
			||||||
            self.put_in_progress = False
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            # put() was unblocked by close() rather than get() or get_iter().
 | 
					 | 
				
			||||||
            if self.closed:
 | 
					 | 
				
			||||||
                raise EOFError("stream of frames ended")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            assert self.message_fetched.is_set()
 | 
					 | 
				
			||||||
            self.message_fetched.clear()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            self.decoder = None
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def close(self) -> None:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        End the stream of frames.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Callling :meth:`close` concurrently with :meth:`get`, :meth:`get_iter`,
 | 
					 | 
				
			||||||
        or :meth:`put` is safe. They will raise :exc:`EOFError`.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        with self.mutex:
 | 
					 | 
				
			||||||
            if self.closed:
 | 
					 | 
				
			||||||
                return
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            self.closed = True
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            # Unblock get or get_iter.
 | 
					 | 
				
			||||||
            if self.get_in_progress:
 | 
					 | 
				
			||||||
                self.message_complete.set()
 | 
					 | 
				
			||||||
                if self.chunks_queue is not None:
 | 
					 | 
				
			||||||
                    self.chunks_queue.put(None)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            # Unblock put().
 | 
					 | 
				
			||||||
            if self.put_in_progress:
 | 
					 | 
				
			||||||
                self.message_fetched.set()
 | 
					 | 
				
			||||||
@@ -1,727 +0,0 @@
 | 
				
			|||||||
from __future__ import annotations
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
import hmac
 | 
					 | 
				
			||||||
import http
 | 
					 | 
				
			||||||
import logging
 | 
					 | 
				
			||||||
import os
 | 
					 | 
				
			||||||
import selectors
 | 
					 | 
				
			||||||
import socket
 | 
					 | 
				
			||||||
import ssl as ssl_module
 | 
					 | 
				
			||||||
import sys
 | 
					 | 
				
			||||||
import threading
 | 
					 | 
				
			||||||
import warnings
 | 
					 | 
				
			||||||
from types import TracebackType
 | 
					 | 
				
			||||||
from typing import Any, Callable, Iterable, Sequence, Tuple, cast
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
from ..exceptions import InvalidHeader
 | 
					 | 
				
			||||||
from ..extensions.base import ServerExtensionFactory
 | 
					 | 
				
			||||||
from ..extensions.permessage_deflate import enable_server_permessage_deflate
 | 
					 | 
				
			||||||
from ..frames import CloseCode
 | 
					 | 
				
			||||||
from ..headers import (
 | 
					 | 
				
			||||||
    build_www_authenticate_basic,
 | 
					 | 
				
			||||||
    parse_authorization_basic,
 | 
					 | 
				
			||||||
    validate_subprotocols,
 | 
					 | 
				
			||||||
)
 | 
					 | 
				
			||||||
from ..http11 import SERVER, Request, Response
 | 
					 | 
				
			||||||
from ..protocol import CONNECTING, OPEN, Event
 | 
					 | 
				
			||||||
from ..server import ServerProtocol
 | 
					 | 
				
			||||||
from ..typing import LoggerLike, Origin, StatusLike, Subprotocol
 | 
					 | 
				
			||||||
from .connection import Connection
 | 
					 | 
				
			||||||
from .utils import Deadline
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
__all__ = ["serve", "unix_serve", "ServerConnection", "Server", "basic_auth"]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class ServerConnection(Connection):
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    :mod:`threading` implementation of a WebSocket server connection.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    :class:`ServerConnection` provides :meth:`recv` and :meth:`send` methods for
 | 
					 | 
				
			||||||
    receiving and sending messages.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    It supports iteration to receive messages::
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        for message in websocket:
 | 
					 | 
				
			||||||
            process(message)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    The iterator exits normally when the connection is closed with close code
 | 
					 | 
				
			||||||
    1000 (OK) or 1001 (going away) or without a close code. It raises a
 | 
					 | 
				
			||||||
    :exc:`~websockets.exceptions.ConnectionClosedError` when the connection is
 | 
					 | 
				
			||||||
    closed with any other code.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Args:
 | 
					 | 
				
			||||||
        socket: Socket connected to a WebSocket client.
 | 
					 | 
				
			||||||
        protocol: Sans-I/O connection.
 | 
					 | 
				
			||||||
        close_timeout: Timeout for closing the connection in seconds.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __init__(
 | 
					 | 
				
			||||||
        self,
 | 
					 | 
				
			||||||
        socket: socket.socket,
 | 
					 | 
				
			||||||
        protocol: ServerProtocol,
 | 
					 | 
				
			||||||
        *,
 | 
					 | 
				
			||||||
        close_timeout: float | None = 10,
 | 
					 | 
				
			||||||
    ) -> None:
 | 
					 | 
				
			||||||
        self.protocol: ServerProtocol
 | 
					 | 
				
			||||||
        self.request_rcvd = threading.Event()
 | 
					 | 
				
			||||||
        super().__init__(
 | 
					 | 
				
			||||||
            socket,
 | 
					 | 
				
			||||||
            protocol,
 | 
					 | 
				
			||||||
            close_timeout=close_timeout,
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        self.username: str  # see basic_auth()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def respond(self, status: StatusLike, text: str) -> Response:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Create a plain text HTTP response.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        ``process_request`` and ``process_response`` may call this method to
 | 
					 | 
				
			||||||
        return an HTTP response instead of performing the WebSocket opening
 | 
					 | 
				
			||||||
        handshake.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        You can modify the response before returning it, for example by changing
 | 
					 | 
				
			||||||
        HTTP headers.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Args:
 | 
					 | 
				
			||||||
            status: HTTP status code.
 | 
					 | 
				
			||||||
            text: HTTP response body; it will be encoded to UTF-8.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Returns:
 | 
					 | 
				
			||||||
            HTTP response to send to the client.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        return self.protocol.reject(status, text)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def handshake(
 | 
					 | 
				
			||||||
        self,
 | 
					 | 
				
			||||||
        process_request: (
 | 
					 | 
				
			||||||
            Callable[
 | 
					 | 
				
			||||||
                [ServerConnection, Request],
 | 
					 | 
				
			||||||
                Response | None,
 | 
					 | 
				
			||||||
            ]
 | 
					 | 
				
			||||||
            | None
 | 
					 | 
				
			||||||
        ) = None,
 | 
					 | 
				
			||||||
        process_response: (
 | 
					 | 
				
			||||||
            Callable[
 | 
					 | 
				
			||||||
                [ServerConnection, Request, Response],
 | 
					 | 
				
			||||||
                Response | None,
 | 
					 | 
				
			||||||
            ]
 | 
					 | 
				
			||||||
            | None
 | 
					 | 
				
			||||||
        ) = None,
 | 
					 | 
				
			||||||
        server_header: str | None = SERVER,
 | 
					 | 
				
			||||||
        timeout: float | None = None,
 | 
					 | 
				
			||||||
    ) -> None:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Perform the opening handshake.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        if not self.request_rcvd.wait(timeout):
 | 
					 | 
				
			||||||
            raise TimeoutError("timed out during handshake")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if self.request is not None:
 | 
					 | 
				
			||||||
            with self.send_context(expected_state=CONNECTING):
 | 
					 | 
				
			||||||
                response = None
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                if process_request is not None:
 | 
					 | 
				
			||||||
                    try:
 | 
					 | 
				
			||||||
                        response = process_request(self, self.request)
 | 
					 | 
				
			||||||
                    except Exception as exc:
 | 
					 | 
				
			||||||
                        self.protocol.handshake_exc = exc
 | 
					 | 
				
			||||||
                        response = self.protocol.reject(
 | 
					 | 
				
			||||||
                            http.HTTPStatus.INTERNAL_SERVER_ERROR,
 | 
					 | 
				
			||||||
                            (
 | 
					 | 
				
			||||||
                                "Failed to open a WebSocket connection.\n"
 | 
					 | 
				
			||||||
                                "See server log for more information.\n"
 | 
					 | 
				
			||||||
                            ),
 | 
					 | 
				
			||||||
                        )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                if response is None:
 | 
					 | 
				
			||||||
                    self.response = self.protocol.accept(self.request)
 | 
					 | 
				
			||||||
                else:
 | 
					 | 
				
			||||||
                    self.response = response
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                if server_header:
 | 
					 | 
				
			||||||
                    self.response.headers["Server"] = server_header
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                response = None
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                if process_response is not None:
 | 
					 | 
				
			||||||
                    try:
 | 
					 | 
				
			||||||
                        response = process_response(self, self.request, self.response)
 | 
					 | 
				
			||||||
                    except Exception as exc:
 | 
					 | 
				
			||||||
                        self.protocol.handshake_exc = exc
 | 
					 | 
				
			||||||
                        response = self.protocol.reject(
 | 
					 | 
				
			||||||
                            http.HTTPStatus.INTERNAL_SERVER_ERROR,
 | 
					 | 
				
			||||||
                            (
 | 
					 | 
				
			||||||
                                "Failed to open a WebSocket connection.\n"
 | 
					 | 
				
			||||||
                                "See server log for more information.\n"
 | 
					 | 
				
			||||||
                            ),
 | 
					 | 
				
			||||||
                        )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                    if response is not None:
 | 
					 | 
				
			||||||
                        self.response = response
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                self.protocol.send_response(self.response)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # self.protocol.handshake_exc is always set when the connection is lost
 | 
					 | 
				
			||||||
        # before receiving a request, when the request cannot be parsed, when
 | 
					 | 
				
			||||||
        # the handshake encounters an error, or when process_request or
 | 
					 | 
				
			||||||
        # process_response sends a HTTP response that rejects the handshake.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if self.protocol.handshake_exc is not None:
 | 
					 | 
				
			||||||
            raise self.protocol.handshake_exc
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def process_event(self, event: Event) -> None:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Process one incoming event.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        # First event - handshake request.
 | 
					 | 
				
			||||||
        if self.request is None:
 | 
					 | 
				
			||||||
            assert isinstance(event, Request)
 | 
					 | 
				
			||||||
            self.request = event
 | 
					 | 
				
			||||||
            self.request_rcvd.set()
 | 
					 | 
				
			||||||
        # Later events - frames.
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            super().process_event(event)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def recv_events(self) -> None:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Read incoming data from the socket and process events.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        try:
 | 
					 | 
				
			||||||
            super().recv_events()
 | 
					 | 
				
			||||||
        finally:
 | 
					 | 
				
			||||||
            # If the connection is closed during the handshake, unblock it.
 | 
					 | 
				
			||||||
            self.request_rcvd.set()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class Server:
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    WebSocket server returned by :func:`serve`.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    This class mirrors the API of :class:`~socketserver.BaseServer`, notably the
 | 
					 | 
				
			||||||
    :meth:`~socketserver.BaseServer.serve_forever` and
 | 
					 | 
				
			||||||
    :meth:`~socketserver.BaseServer.shutdown` methods, as well as the context
 | 
					 | 
				
			||||||
    manager protocol.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Args:
 | 
					 | 
				
			||||||
        socket: Server socket listening for new connections.
 | 
					 | 
				
			||||||
        handler: Handler for one connection. Receives the socket and address
 | 
					 | 
				
			||||||
            returned by :meth:`~socket.socket.accept`.
 | 
					 | 
				
			||||||
        logger: Logger for this server.
 | 
					 | 
				
			||||||
            It defaults to ``logging.getLogger("websockets.server")``.
 | 
					 | 
				
			||||||
            See the :doc:`logging guide <../../topics/logging>` for details.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __init__(
 | 
					 | 
				
			||||||
        self,
 | 
					 | 
				
			||||||
        socket: socket.socket,
 | 
					 | 
				
			||||||
        handler: Callable[[socket.socket, Any], None],
 | 
					 | 
				
			||||||
        logger: LoggerLike | None = None,
 | 
					 | 
				
			||||||
    ) -> None:
 | 
					 | 
				
			||||||
        self.socket = socket
 | 
					 | 
				
			||||||
        self.handler = handler
 | 
					 | 
				
			||||||
        if logger is None:
 | 
					 | 
				
			||||||
            logger = logging.getLogger("websockets.server")
 | 
					 | 
				
			||||||
        self.logger = logger
 | 
					 | 
				
			||||||
        if sys.platform != "win32":
 | 
					 | 
				
			||||||
            self.shutdown_watcher, self.shutdown_notifier = os.pipe()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def serve_forever(self) -> None:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        See :meth:`socketserver.BaseServer.serve_forever`.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        This method doesn't return. Calling :meth:`shutdown` from another thread
 | 
					 | 
				
			||||||
        stops the server.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Typical use::
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            with serve(...) as server:
 | 
					 | 
				
			||||||
                server.serve_forever()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        poller = selectors.DefaultSelector()
 | 
					 | 
				
			||||||
        try:
 | 
					 | 
				
			||||||
            poller.register(self.socket, selectors.EVENT_READ)
 | 
					 | 
				
			||||||
        except ValueError:  # pragma: no cover
 | 
					 | 
				
			||||||
            # If shutdown() is called before poller.register(),
 | 
					 | 
				
			||||||
            # the socket is closed and poller.register() raises
 | 
					 | 
				
			||||||
            # ValueError: Invalid file descriptor: -1
 | 
					 | 
				
			||||||
            return
 | 
					 | 
				
			||||||
        if sys.platform != "win32":
 | 
					 | 
				
			||||||
            poller.register(self.shutdown_watcher, selectors.EVENT_READ)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        while True:
 | 
					 | 
				
			||||||
            poller.select()
 | 
					 | 
				
			||||||
            try:
 | 
					 | 
				
			||||||
                # If the socket is closed, this will raise an exception and exit
 | 
					 | 
				
			||||||
                # the loop. So we don't need to check the return value of select().
 | 
					 | 
				
			||||||
                sock, addr = self.socket.accept()
 | 
					 | 
				
			||||||
            except OSError:
 | 
					 | 
				
			||||||
                break
 | 
					 | 
				
			||||||
            # Since there isn't a mechanism for tracking connections and waiting
 | 
					 | 
				
			||||||
            # for them to terminate, we cannot use daemon threads, or else all
 | 
					 | 
				
			||||||
            # connections would be terminate brutally when closing the server.
 | 
					 | 
				
			||||||
            thread = threading.Thread(target=self.handler, args=(sock, addr))
 | 
					 | 
				
			||||||
            thread.start()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def shutdown(self) -> None:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        See :meth:`socketserver.BaseServer.shutdown`.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        self.socket.close()
 | 
					 | 
				
			||||||
        if sys.platform != "win32":
 | 
					 | 
				
			||||||
            os.write(self.shutdown_notifier, b"x")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def fileno(self) -> int:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        See :meth:`socketserver.BaseServer.fileno`.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        return self.socket.fileno()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __enter__(self) -> Server:
 | 
					 | 
				
			||||||
        return self
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __exit__(
 | 
					 | 
				
			||||||
        self,
 | 
					 | 
				
			||||||
        exc_type: type[BaseException] | None,
 | 
					 | 
				
			||||||
        exc_value: BaseException | None,
 | 
					 | 
				
			||||||
        traceback: TracebackType | None,
 | 
					 | 
				
			||||||
    ) -> None:
 | 
					 | 
				
			||||||
        self.shutdown()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def __getattr__(name: str) -> Any:
 | 
					 | 
				
			||||||
    if name == "WebSocketServer":
 | 
					 | 
				
			||||||
        warnings.warn(  # deprecated in 13.0 - 2024-08-20
 | 
					 | 
				
			||||||
            "WebSocketServer was renamed to Server",
 | 
					 | 
				
			||||||
            DeprecationWarning,
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        return Server
 | 
					 | 
				
			||||||
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def serve(
 | 
					 | 
				
			||||||
    handler: Callable[[ServerConnection], None],
 | 
					 | 
				
			||||||
    host: str | None = None,
 | 
					 | 
				
			||||||
    port: int | None = None,
 | 
					 | 
				
			||||||
    *,
 | 
					 | 
				
			||||||
    # TCP/TLS
 | 
					 | 
				
			||||||
    sock: socket.socket | None = None,
 | 
					 | 
				
			||||||
    ssl: ssl_module.SSLContext | None = None,
 | 
					 | 
				
			||||||
    # WebSocket
 | 
					 | 
				
			||||||
    origins: Sequence[Origin | None] | None = None,
 | 
					 | 
				
			||||||
    extensions: Sequence[ServerExtensionFactory] | None = None,
 | 
					 | 
				
			||||||
    subprotocols: Sequence[Subprotocol] | None = None,
 | 
					 | 
				
			||||||
    select_subprotocol: (
 | 
					 | 
				
			||||||
        Callable[
 | 
					 | 
				
			||||||
            [ServerConnection, Sequence[Subprotocol]],
 | 
					 | 
				
			||||||
            Subprotocol | None,
 | 
					 | 
				
			||||||
        ]
 | 
					 | 
				
			||||||
        | None
 | 
					 | 
				
			||||||
    ) = None,
 | 
					 | 
				
			||||||
    process_request: (
 | 
					 | 
				
			||||||
        Callable[
 | 
					 | 
				
			||||||
            [ServerConnection, Request],
 | 
					 | 
				
			||||||
            Response | None,
 | 
					 | 
				
			||||||
        ]
 | 
					 | 
				
			||||||
        | None
 | 
					 | 
				
			||||||
    ) = None,
 | 
					 | 
				
			||||||
    process_response: (
 | 
					 | 
				
			||||||
        Callable[
 | 
					 | 
				
			||||||
            [ServerConnection, Request, Response],
 | 
					 | 
				
			||||||
            Response | None,
 | 
					 | 
				
			||||||
        ]
 | 
					 | 
				
			||||||
        | None
 | 
					 | 
				
			||||||
    ) = None,
 | 
					 | 
				
			||||||
    server_header: str | None = SERVER,
 | 
					 | 
				
			||||||
    compression: str | None = "deflate",
 | 
					 | 
				
			||||||
    # Timeouts
 | 
					 | 
				
			||||||
    open_timeout: float | None = 10,
 | 
					 | 
				
			||||||
    close_timeout: float | None = 10,
 | 
					 | 
				
			||||||
    # Limits
 | 
					 | 
				
			||||||
    max_size: int | None = 2**20,
 | 
					 | 
				
			||||||
    # Logging
 | 
					 | 
				
			||||||
    logger: LoggerLike | None = None,
 | 
					 | 
				
			||||||
    # Escape hatch for advanced customization
 | 
					 | 
				
			||||||
    create_connection: type[ServerConnection] | None = None,
 | 
					 | 
				
			||||||
    **kwargs: Any,
 | 
					 | 
				
			||||||
) -> Server:
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Create a WebSocket server listening on ``host`` and ``port``.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Whenever a client connects, the server creates a :class:`ServerConnection`,
 | 
					 | 
				
			||||||
    performs the opening handshake, and delegates to the ``handler``.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    The handler receives the :class:`ServerConnection` instance, which you can
 | 
					 | 
				
			||||||
    use to send and receive messages.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Once the handler completes, either normally or with an exception, the server
 | 
					 | 
				
			||||||
    performs the closing handshake and closes the connection.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    This function returns a :class:`Server` whose API mirrors
 | 
					 | 
				
			||||||
    :class:`~socketserver.BaseServer`. Treat it as a context manager to ensure
 | 
					 | 
				
			||||||
    that it will be closed and call :meth:`~Server.serve_forever` to serve
 | 
					 | 
				
			||||||
    requests::
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        from websockets.sync.server import serve
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        def handler(websocket):
 | 
					 | 
				
			||||||
            ...
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        with serve(handler, ...) as server:
 | 
					 | 
				
			||||||
            server.serve_forever()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Args:
 | 
					 | 
				
			||||||
        handler: Connection handler. It receives the WebSocket connection,
 | 
					 | 
				
			||||||
            which is a :class:`ServerConnection`, in argument.
 | 
					 | 
				
			||||||
        host: Network interfaces the server binds to.
 | 
					 | 
				
			||||||
            See :func:`~socket.create_server` for details.
 | 
					 | 
				
			||||||
        port: TCP port the server listens on.
 | 
					 | 
				
			||||||
            See :func:`~socket.create_server` for details.
 | 
					 | 
				
			||||||
        sock: Preexisting TCP socket. ``sock`` replaces ``host`` and ``port``.
 | 
					 | 
				
			||||||
            You may call :func:`socket.create_server` to create a suitable TCP
 | 
					 | 
				
			||||||
            socket.
 | 
					 | 
				
			||||||
        ssl: Configuration for enabling TLS on the connection.
 | 
					 | 
				
			||||||
        origins: Acceptable values of the ``Origin`` header, for defending
 | 
					 | 
				
			||||||
            against Cross-Site WebSocket Hijacking attacks. Include :obj:`None`
 | 
					 | 
				
			||||||
            in the list if the lack of an origin is acceptable.
 | 
					 | 
				
			||||||
        extensions: List of supported extensions, in order in which they
 | 
					 | 
				
			||||||
            should be negotiated and run.
 | 
					 | 
				
			||||||
        subprotocols: List of supported subprotocols, in order of decreasing
 | 
					 | 
				
			||||||
            preference.
 | 
					 | 
				
			||||||
        select_subprotocol: Callback for selecting a subprotocol among
 | 
					 | 
				
			||||||
            those supported by the client and the server. It receives a
 | 
					 | 
				
			||||||
            :class:`ServerConnection` (not a
 | 
					 | 
				
			||||||
            :class:`~websockets.server.ServerProtocol`!) instance and a list of
 | 
					 | 
				
			||||||
            subprotocols offered by the client. Other than the first argument,
 | 
					 | 
				
			||||||
            it has the same behavior as the
 | 
					 | 
				
			||||||
            :meth:`ServerProtocol.select_subprotocol
 | 
					 | 
				
			||||||
            <websockets.server.ServerProtocol.select_subprotocol>` method.
 | 
					 | 
				
			||||||
        process_request: Intercept the request during the opening handshake.
 | 
					 | 
				
			||||||
            Return an HTTP response to force the response. Return :obj:`None` to
 | 
					 | 
				
			||||||
            continue normally. When you force an HTTP 101 Continue response, the
 | 
					 | 
				
			||||||
            handshake is successful. Else, the connection is aborted.
 | 
					 | 
				
			||||||
        process_response: Intercept the response during the opening handshake.
 | 
					 | 
				
			||||||
            Modify the response or return a new HTTP response to force the
 | 
					 | 
				
			||||||
            response. Return :obj:`None` to continue normally. When you force an
 | 
					 | 
				
			||||||
            HTTP 101 Continue response, the handshake is successful. Else, the
 | 
					 | 
				
			||||||
            connection is aborted.
 | 
					 | 
				
			||||||
        server_header: Value of  the ``Server`` response header.
 | 
					 | 
				
			||||||
            It defaults to ``"Python/x.y.z websockets/X.Y"``. Setting it to
 | 
					 | 
				
			||||||
            :obj:`None` removes the header.
 | 
					 | 
				
			||||||
        compression: The "permessage-deflate" extension is enabled by default.
 | 
					 | 
				
			||||||
            Set ``compression`` to :obj:`None` to disable it. See the
 | 
					 | 
				
			||||||
            :doc:`compression guide <../../topics/compression>` for details.
 | 
					 | 
				
			||||||
        open_timeout: Timeout for opening connections in seconds.
 | 
					 | 
				
			||||||
            :obj:`None` disables the timeout.
 | 
					 | 
				
			||||||
        close_timeout: Timeout for closing connections in seconds.
 | 
					 | 
				
			||||||
            :obj:`None` disables the timeout.
 | 
					 | 
				
			||||||
        max_size: Maximum size of incoming messages in bytes.
 | 
					 | 
				
			||||||
            :obj:`None` disables the limit.
 | 
					 | 
				
			||||||
        logger: Logger for this server.
 | 
					 | 
				
			||||||
            It defaults to ``logging.getLogger("websockets.server")``. See the
 | 
					 | 
				
			||||||
            :doc:`logging guide <../../topics/logging>` for details.
 | 
					 | 
				
			||||||
        create_connection: Factory for the :class:`ServerConnection` managing
 | 
					 | 
				
			||||||
            the connection. Set it to a wrapper or a subclass to customize
 | 
					 | 
				
			||||||
            connection handling.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Any other keyword arguments are passed to :func:`~socket.create_server`.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    # Process parameters
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    # Backwards compatibility: ssl used to be called ssl_context.
 | 
					 | 
				
			||||||
    if ssl is None and "ssl_context" in kwargs:
 | 
					 | 
				
			||||||
        ssl = kwargs.pop("ssl_context")
 | 
					 | 
				
			||||||
        warnings.warn(  # deprecated in 13.0 - 2024-08-20
 | 
					 | 
				
			||||||
            "ssl_context was renamed to ssl",
 | 
					 | 
				
			||||||
            DeprecationWarning,
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    if subprotocols is not None:
 | 
					 | 
				
			||||||
        validate_subprotocols(subprotocols)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    if compression == "deflate":
 | 
					 | 
				
			||||||
        extensions = enable_server_permessage_deflate(extensions)
 | 
					 | 
				
			||||||
    elif compression is not None:
 | 
					 | 
				
			||||||
        raise ValueError(f"unsupported compression: {compression}")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    if create_connection is None:
 | 
					 | 
				
			||||||
        create_connection = ServerConnection
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    # Bind socket and listen
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    # Private APIs for unix_connect()
 | 
					 | 
				
			||||||
    unix: bool = kwargs.pop("unix", False)
 | 
					 | 
				
			||||||
    path: str | None = kwargs.pop("path", None)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    if sock is None:
 | 
					 | 
				
			||||||
        if unix:
 | 
					 | 
				
			||||||
            if path is None:
 | 
					 | 
				
			||||||
                raise TypeError("missing path argument")
 | 
					 | 
				
			||||||
            kwargs.setdefault("family", socket.AF_UNIX)
 | 
					 | 
				
			||||||
            sock = socket.create_server(path, **kwargs)
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            sock = socket.create_server((host, port), **kwargs)
 | 
					 | 
				
			||||||
    else:
 | 
					 | 
				
			||||||
        if path is not None:
 | 
					 | 
				
			||||||
            raise TypeError("path and sock arguments are incompatible")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    # Initialize TLS wrapper
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    if ssl is not None:
 | 
					 | 
				
			||||||
        sock = ssl.wrap_socket(
 | 
					 | 
				
			||||||
            sock,
 | 
					 | 
				
			||||||
            server_side=True,
 | 
					 | 
				
			||||||
            # Delay TLS handshake until after we set a timeout on the socket.
 | 
					 | 
				
			||||||
            do_handshake_on_connect=False,
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    # Define request handler
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def conn_handler(sock: socket.socket, addr: Any) -> None:
 | 
					 | 
				
			||||||
        # Calculate timeouts on the TLS and WebSocket handshakes.
 | 
					 | 
				
			||||||
        # The TLS timeout must be set on the socket, then removed
 | 
					 | 
				
			||||||
        # to avoid conflicting with the WebSocket timeout in handshake().
 | 
					 | 
				
			||||||
        deadline = Deadline(open_timeout)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        try:
 | 
					 | 
				
			||||||
            # Disable Nagle algorithm
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            if not unix:
 | 
					 | 
				
			||||||
                sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, True)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            # Perform TLS handshake
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            if ssl is not None:
 | 
					 | 
				
			||||||
                sock.settimeout(deadline.timeout())
 | 
					 | 
				
			||||||
                # mypy cannot figure this out
 | 
					 | 
				
			||||||
                assert isinstance(sock, ssl_module.SSLSocket)
 | 
					 | 
				
			||||||
                sock.do_handshake()
 | 
					 | 
				
			||||||
                sock.settimeout(None)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            # Create a closure to give select_subprotocol access to connection.
 | 
					 | 
				
			||||||
            protocol_select_subprotocol: (
 | 
					 | 
				
			||||||
                Callable[
 | 
					 | 
				
			||||||
                    [ServerProtocol, Sequence[Subprotocol]],
 | 
					 | 
				
			||||||
                    Subprotocol | None,
 | 
					 | 
				
			||||||
                ]
 | 
					 | 
				
			||||||
                | None
 | 
					 | 
				
			||||||
            ) = None
 | 
					 | 
				
			||||||
            if select_subprotocol is not None:
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                def protocol_select_subprotocol(
 | 
					 | 
				
			||||||
                    protocol: ServerProtocol,
 | 
					 | 
				
			||||||
                    subprotocols: Sequence[Subprotocol],
 | 
					 | 
				
			||||||
                ) -> Subprotocol | None:
 | 
					 | 
				
			||||||
                    # mypy doesn't know that select_subprotocol is immutable.
 | 
					 | 
				
			||||||
                    assert select_subprotocol is not None
 | 
					 | 
				
			||||||
                    # Ensure this function is only used in the intended context.
 | 
					 | 
				
			||||||
                    assert protocol is connection.protocol
 | 
					 | 
				
			||||||
                    return select_subprotocol(connection, subprotocols)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            # Initialize WebSocket protocol
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            protocol = ServerProtocol(
 | 
					 | 
				
			||||||
                origins=origins,
 | 
					 | 
				
			||||||
                extensions=extensions,
 | 
					 | 
				
			||||||
                subprotocols=subprotocols,
 | 
					 | 
				
			||||||
                select_subprotocol=protocol_select_subprotocol,
 | 
					 | 
				
			||||||
                max_size=max_size,
 | 
					 | 
				
			||||||
                logger=logger,
 | 
					 | 
				
			||||||
            )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            # Initialize WebSocket connection
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            assert create_connection is not None  # help mypy
 | 
					 | 
				
			||||||
            connection = create_connection(
 | 
					 | 
				
			||||||
                sock,
 | 
					 | 
				
			||||||
                protocol,
 | 
					 | 
				
			||||||
                close_timeout=close_timeout,
 | 
					 | 
				
			||||||
            )
 | 
					 | 
				
			||||||
        except Exception:
 | 
					 | 
				
			||||||
            sock.close()
 | 
					 | 
				
			||||||
            return
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        try:
 | 
					 | 
				
			||||||
            try:
 | 
					 | 
				
			||||||
                connection.handshake(
 | 
					 | 
				
			||||||
                    process_request,
 | 
					 | 
				
			||||||
                    process_response,
 | 
					 | 
				
			||||||
                    server_header,
 | 
					 | 
				
			||||||
                    deadline.timeout(),
 | 
					 | 
				
			||||||
                )
 | 
					 | 
				
			||||||
            except TimeoutError:
 | 
					 | 
				
			||||||
                connection.close_socket()
 | 
					 | 
				
			||||||
                connection.recv_events_thread.join()
 | 
					 | 
				
			||||||
                return
 | 
					 | 
				
			||||||
            except Exception:
 | 
					 | 
				
			||||||
                connection.logger.error("opening handshake failed", exc_info=True)
 | 
					 | 
				
			||||||
                connection.close_socket()
 | 
					 | 
				
			||||||
                connection.recv_events_thread.join()
 | 
					 | 
				
			||||||
                return
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            assert connection.protocol.state is OPEN
 | 
					 | 
				
			||||||
            try:
 | 
					 | 
				
			||||||
                handler(connection)
 | 
					 | 
				
			||||||
            except Exception:
 | 
					 | 
				
			||||||
                connection.logger.error("connection handler failed", exc_info=True)
 | 
					 | 
				
			||||||
                connection.close(CloseCode.INTERNAL_ERROR)
 | 
					 | 
				
			||||||
            else:
 | 
					 | 
				
			||||||
                connection.close()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        except Exception:  # pragma: no cover
 | 
					 | 
				
			||||||
            # Don't leak sockets on unexpected errors.
 | 
					 | 
				
			||||||
            sock.close()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    # Initialize server
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    return Server(sock, conn_handler, logger)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def unix_serve(
 | 
					 | 
				
			||||||
    handler: Callable[[ServerConnection], None],
 | 
					 | 
				
			||||||
    path: str | None = None,
 | 
					 | 
				
			||||||
    **kwargs: Any,
 | 
					 | 
				
			||||||
) -> Server:
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Create a WebSocket server listening on a Unix socket.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    This function accepts the same keyword arguments as :func:`serve`.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    It's only available on Unix.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    It's useful for deploying a server behind a reverse proxy such as nginx.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Args:
 | 
					 | 
				
			||||||
        handler: Connection handler. It receives the WebSocket connection,
 | 
					 | 
				
			||||||
            which is a :class:`ServerConnection`, in argument.
 | 
					 | 
				
			||||||
        path: File system path to the Unix socket.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    return serve(handler, unix=True, path=path, **kwargs)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def is_credentials(credentials: Any) -> bool:
 | 
					 | 
				
			||||||
    try:
 | 
					 | 
				
			||||||
        username, password = credentials
 | 
					 | 
				
			||||||
    except (TypeError, ValueError):
 | 
					 | 
				
			||||||
        return False
 | 
					 | 
				
			||||||
    else:
 | 
					 | 
				
			||||||
        return isinstance(username, str) and isinstance(password, str)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def basic_auth(
 | 
					 | 
				
			||||||
    realm: str = "",
 | 
					 | 
				
			||||||
    credentials: tuple[str, str] | Iterable[tuple[str, str]] | None = None,
 | 
					 | 
				
			||||||
    check_credentials: Callable[[str, str], bool] | None = None,
 | 
					 | 
				
			||||||
) -> Callable[[ServerConnection, Request], Response | None]:
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Factory for ``process_request`` to enforce HTTP Basic Authentication.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    :func:`basic_auth` is designed to integrate with :func:`serve` as follows::
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        from websockets.sync.server import basic_auth, serve
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        with serve(
 | 
					 | 
				
			||||||
            ...,
 | 
					 | 
				
			||||||
            process_request=basic_auth(
 | 
					 | 
				
			||||||
                realm="my dev server",
 | 
					 | 
				
			||||||
                credentials=("hello", "iloveyou"),
 | 
					 | 
				
			||||||
            ),
 | 
					 | 
				
			||||||
        ):
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    If authentication succeeds, the connection's ``username`` attribute is set.
 | 
					 | 
				
			||||||
    If it fails, the server responds with an HTTP 401 Unauthorized status.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    One of ``credentials`` or ``check_credentials`` must be provided; not both.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Args:
 | 
					 | 
				
			||||||
        realm: Scope of protection. It should contain only ASCII characters
 | 
					 | 
				
			||||||
            because the encoding of non-ASCII characters is undefined. Refer to
 | 
					 | 
				
			||||||
            section 2.2 of :rfc:`7235` for details.
 | 
					 | 
				
			||||||
        credentials: Hard coded authorized credentials. It can be a
 | 
					 | 
				
			||||||
            ``(username, password)`` pair or a list of such pairs.
 | 
					 | 
				
			||||||
        check_credentials: Function that verifies credentials.
 | 
					 | 
				
			||||||
            It receives ``username`` and ``password`` arguments and returns
 | 
					 | 
				
			||||||
            whether they're valid.
 | 
					 | 
				
			||||||
    Raises:
 | 
					 | 
				
			||||||
        TypeError: If ``credentials`` or ``check_credentials`` is wrong.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    if (credentials is None) == (check_credentials is None):
 | 
					 | 
				
			||||||
        raise TypeError("provide either credentials or check_credentials")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    if credentials is not None:
 | 
					 | 
				
			||||||
        if is_credentials(credentials):
 | 
					 | 
				
			||||||
            credentials_list = [cast(Tuple[str, str], credentials)]
 | 
					 | 
				
			||||||
        elif isinstance(credentials, Iterable):
 | 
					 | 
				
			||||||
            credentials_list = list(cast(Iterable[Tuple[str, str]], credentials))
 | 
					 | 
				
			||||||
            if not all(is_credentials(item) for item in credentials_list):
 | 
					 | 
				
			||||||
                raise TypeError(f"invalid credentials argument: {credentials}")
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            raise TypeError(f"invalid credentials argument: {credentials}")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        credentials_dict = dict(credentials_list)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        def check_credentials(username: str, password: str) -> bool:
 | 
					 | 
				
			||||||
            try:
 | 
					 | 
				
			||||||
                expected_password = credentials_dict[username]
 | 
					 | 
				
			||||||
            except KeyError:
 | 
					 | 
				
			||||||
                return False
 | 
					 | 
				
			||||||
            return hmac.compare_digest(expected_password, password)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    assert check_credentials is not None  # help mypy
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def process_request(
 | 
					 | 
				
			||||||
        connection: ServerConnection,
 | 
					 | 
				
			||||||
        request: Request,
 | 
					 | 
				
			||||||
    ) -> Response | None:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Perform HTTP Basic Authentication.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        If it succeeds, set the connection's ``username`` attribute and return
 | 
					 | 
				
			||||||
        :obj:`None`. If it fails, return an HTTP 401 Unauthorized responss.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        try:
 | 
					 | 
				
			||||||
            authorization = request.headers["Authorization"]
 | 
					 | 
				
			||||||
        except KeyError:
 | 
					 | 
				
			||||||
            response = connection.respond(
 | 
					 | 
				
			||||||
                http.HTTPStatus.UNAUTHORIZED,
 | 
					 | 
				
			||||||
                "Missing credentials\n",
 | 
					 | 
				
			||||||
            )
 | 
					 | 
				
			||||||
            response.headers["WWW-Authenticate"] = build_www_authenticate_basic(realm)
 | 
					 | 
				
			||||||
            return response
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        try:
 | 
					 | 
				
			||||||
            username, password = parse_authorization_basic(authorization)
 | 
					 | 
				
			||||||
        except InvalidHeader:
 | 
					 | 
				
			||||||
            response = connection.respond(
 | 
					 | 
				
			||||||
                http.HTTPStatus.UNAUTHORIZED,
 | 
					 | 
				
			||||||
                "Unsupported credentials\n",
 | 
					 | 
				
			||||||
            )
 | 
					 | 
				
			||||||
            response.headers["WWW-Authenticate"] = build_www_authenticate_basic(realm)
 | 
					 | 
				
			||||||
            return response
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if not check_credentials(username, password):
 | 
					 | 
				
			||||||
            response = connection.respond(
 | 
					 | 
				
			||||||
                http.HTTPStatus.UNAUTHORIZED,
 | 
					 | 
				
			||||||
                "Invalid credentials\n",
 | 
					 | 
				
			||||||
            )
 | 
					 | 
				
			||||||
            response.headers["WWW-Authenticate"] = build_www_authenticate_basic(realm)
 | 
					 | 
				
			||||||
            return response
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        connection.username = username
 | 
					 | 
				
			||||||
        return None
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    return process_request
 | 
					 | 
				
			||||||
@@ -1,45 +0,0 @@
 | 
				
			|||||||
from __future__ import annotations
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
import time
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
__all__ = ["Deadline"]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class Deadline:
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Manage timeouts across multiple steps.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Args:
 | 
					 | 
				
			||||||
        timeout: Time available in seconds or :obj:`None` if there is no limit.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __init__(self, timeout: float | None) -> None:
 | 
					 | 
				
			||||||
        self.deadline: float | None
 | 
					 | 
				
			||||||
        if timeout is None:
 | 
					 | 
				
			||||||
            self.deadline = None
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            self.deadline = time.monotonic() + timeout
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def timeout(self, *, raise_if_elapsed: bool = True) -> float | None:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Calculate a timeout from a deadline.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Args:
 | 
					 | 
				
			||||||
            raise_if_elapsed: Whether to raise :exc:`TimeoutError`
 | 
					 | 
				
			||||||
                if the deadline lapsed.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Raises:
 | 
					 | 
				
			||||||
            TimeoutError: If the deadline lapsed.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Returns:
 | 
					 | 
				
			||||||
            Time left in seconds or :obj:`None` if there is no limit.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        if self.deadline is None:
 | 
					 | 
				
			||||||
            return None
 | 
					 | 
				
			||||||
        timeout = self.deadline - time.monotonic()
 | 
					 | 
				
			||||||
        if raise_if_elapsed and timeout <= 0:
 | 
					 | 
				
			||||||
            raise TimeoutError("timed out")
 | 
					 | 
				
			||||||
        return timeout
 | 
					 | 
				
			||||||
@@ -1,77 +0,0 @@
 | 
				
			|||||||
from __future__ import annotations
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
import http
 | 
					 | 
				
			||||||
import logging
 | 
					 | 
				
			||||||
import typing
 | 
					 | 
				
			||||||
from typing import Any, List, NewType, Optional, Tuple, Union
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
__all__ = [
 | 
					 | 
				
			||||||
    "Data",
 | 
					 | 
				
			||||||
    "LoggerLike",
 | 
					 | 
				
			||||||
    "StatusLike",
 | 
					 | 
				
			||||||
    "Origin",
 | 
					 | 
				
			||||||
    "Subprotocol",
 | 
					 | 
				
			||||||
    "ExtensionName",
 | 
					 | 
				
			||||||
    "ExtensionParameter",
 | 
					 | 
				
			||||||
]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
# Public types used in the signature of public APIs
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
# Change to str | bytes when dropping Python < 3.10.
 | 
					 | 
				
			||||||
Data = Union[str, bytes]
 | 
					 | 
				
			||||||
"""Types supported in a WebSocket message:
 | 
					 | 
				
			||||||
:class:`str` for a Text_ frame, :class:`bytes` for a Binary_.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
.. _Text: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6
 | 
					 | 
				
			||||||
.. _Binary : https://datatracker.ietf.org/doc/html/rfc6455#section-5.6
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
"""
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
# Change to logging.Logger | ... when dropping Python < 3.10.
 | 
					 | 
				
			||||||
if typing.TYPE_CHECKING:
 | 
					 | 
				
			||||||
    LoggerLike = Union[logging.Logger, logging.LoggerAdapter[Any]]
 | 
					 | 
				
			||||||
    """Types accepted where a :class:`~logging.Logger` is expected."""
 | 
					 | 
				
			||||||
else:  # remove this branch when dropping support for Python < 3.11
 | 
					 | 
				
			||||||
    LoggerLike = Union[logging.Logger, logging.LoggerAdapter]
 | 
					 | 
				
			||||||
    """Types accepted where a :class:`~logging.Logger` is expected."""
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
# Change to http.HTTPStatus | int when dropping Python < 3.10.
 | 
					 | 
				
			||||||
StatusLike = Union[http.HTTPStatus, int]
 | 
					 | 
				
			||||||
"""
 | 
					 | 
				
			||||||
Types accepted where an :class:`~http.HTTPStatus` is expected."""
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
Origin = NewType("Origin", str)
 | 
					 | 
				
			||||||
"""Value of a ``Origin`` header."""
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
Subprotocol = NewType("Subprotocol", str)
 | 
					 | 
				
			||||||
"""Subprotocol in a ``Sec-WebSocket-Protocol`` header."""
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
ExtensionName = NewType("ExtensionName", str)
 | 
					 | 
				
			||||||
"""Name of a WebSocket extension."""
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
# Change to tuple[str, Optional[str]] when dropping Python < 3.9.
 | 
					 | 
				
			||||||
# Change to tuple[str, str | None] when dropping Python < 3.10.
 | 
					 | 
				
			||||||
ExtensionParameter = Tuple[str, Optional[str]]
 | 
					 | 
				
			||||||
"""Parameter of a WebSocket extension."""
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
# Private types
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
# Change to tuple[.., list[...]] when dropping Python < 3.9.
 | 
					 | 
				
			||||||
ExtensionHeader = Tuple[ExtensionName, List[ExtensionParameter]]
 | 
					 | 
				
			||||||
"""Extension in a ``Sec-WebSocket-Extensions`` header."""
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
ConnectionOption = NewType("ConnectionOption", str)
 | 
					 | 
				
			||||||
"""Connection option in a ``Connection`` header."""
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
UpgradeProtocol = NewType("UpgradeProtocol", str)
 | 
					 | 
				
			||||||
"""Upgrade protocol in an ``Upgrade`` header."""
 | 
					 | 
				
			||||||
@@ -1,107 +0,0 @@
 | 
				
			|||||||
from __future__ import annotations
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
import dataclasses
 | 
					 | 
				
			||||||
import urllib.parse
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
from .exceptions import InvalidURI
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
__all__ = ["parse_uri", "WebSocketURI"]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
@dataclasses.dataclass
 | 
					 | 
				
			||||||
class WebSocketURI:
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    WebSocket URI.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Attributes:
 | 
					 | 
				
			||||||
        secure: :obj:`True` for a ``wss`` URI, :obj:`False` for a ``ws`` URI.
 | 
					 | 
				
			||||||
        host: Normalized to lower case.
 | 
					 | 
				
			||||||
        port: Always set even if it's the default.
 | 
					 | 
				
			||||||
        path: May be empty.
 | 
					 | 
				
			||||||
        query: May be empty if the URI doesn't include a query component.
 | 
					 | 
				
			||||||
        username: Available when the URI contains `User Information`_.
 | 
					 | 
				
			||||||
        password: Available when the URI contains `User Information`_.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    .. _User Information: https://datatracker.ietf.org/doc/html/rfc3986#section-3.2.1
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    secure: bool
 | 
					 | 
				
			||||||
    host: str
 | 
					 | 
				
			||||||
    port: int
 | 
					 | 
				
			||||||
    path: str
 | 
					 | 
				
			||||||
    query: str
 | 
					 | 
				
			||||||
    username: str | None = None
 | 
					 | 
				
			||||||
    password: str | None = None
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    @property
 | 
					 | 
				
			||||||
    def resource_name(self) -> str:
 | 
					 | 
				
			||||||
        if self.path:
 | 
					 | 
				
			||||||
            resource_name = self.path
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            resource_name = "/"
 | 
					 | 
				
			||||||
        if self.query:
 | 
					 | 
				
			||||||
            resource_name += "?" + self.query
 | 
					 | 
				
			||||||
        return resource_name
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    @property
 | 
					 | 
				
			||||||
    def user_info(self) -> tuple[str, str] | None:
 | 
					 | 
				
			||||||
        if self.username is None:
 | 
					 | 
				
			||||||
            return None
 | 
					 | 
				
			||||||
        assert self.password is not None
 | 
					 | 
				
			||||||
        return (self.username, self.password)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
# All characters from the gen-delims and sub-delims sets in RFC 3987.
 | 
					 | 
				
			||||||
DELIMS = ":/?#[]@!$&'()*+,;="
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def parse_uri(uri: str) -> WebSocketURI:
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Parse and validate a WebSocket URI.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Args:
 | 
					 | 
				
			||||||
        uri: WebSocket URI.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Returns:
 | 
					 | 
				
			||||||
        Parsed WebSocket URI.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Raises:
 | 
					 | 
				
			||||||
        InvalidURI: If ``uri`` isn't a valid WebSocket URI.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    parsed = urllib.parse.urlparse(uri)
 | 
					 | 
				
			||||||
    if parsed.scheme not in ["ws", "wss"]:
 | 
					 | 
				
			||||||
        raise InvalidURI(uri, "scheme isn't ws or wss")
 | 
					 | 
				
			||||||
    if parsed.hostname is None:
 | 
					 | 
				
			||||||
        raise InvalidURI(uri, "hostname isn't provided")
 | 
					 | 
				
			||||||
    if parsed.fragment != "":
 | 
					 | 
				
			||||||
        raise InvalidURI(uri, "fragment identifier is meaningless")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    secure = parsed.scheme == "wss"
 | 
					 | 
				
			||||||
    host = parsed.hostname
 | 
					 | 
				
			||||||
    port = parsed.port or (443 if secure else 80)
 | 
					 | 
				
			||||||
    path = parsed.path
 | 
					 | 
				
			||||||
    query = parsed.query
 | 
					 | 
				
			||||||
    username = parsed.username
 | 
					 | 
				
			||||||
    password = parsed.password
 | 
					 | 
				
			||||||
    # urllib.parse.urlparse accepts URLs with a username but without a
 | 
					 | 
				
			||||||
    # password. This doesn't make sense for HTTP Basic Auth credentials.
 | 
					 | 
				
			||||||
    if username is not None and password is None:
 | 
					 | 
				
			||||||
        raise InvalidURI(uri, "username provided without password")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    try:
 | 
					 | 
				
			||||||
        uri.encode("ascii")
 | 
					 | 
				
			||||||
    except UnicodeEncodeError:
 | 
					 | 
				
			||||||
        # Input contains non-ASCII characters.
 | 
					 | 
				
			||||||
        # It must be an IRI. Convert it to a URI.
 | 
					 | 
				
			||||||
        host = host.encode("idna").decode()
 | 
					 | 
				
			||||||
        path = urllib.parse.quote(path, safe=DELIMS)
 | 
					 | 
				
			||||||
        query = urllib.parse.quote(query, safe=DELIMS)
 | 
					 | 
				
			||||||
        if username is not None:
 | 
					 | 
				
			||||||
            assert password is not None
 | 
					 | 
				
			||||||
            username = urllib.parse.quote(username, safe=DELIMS)
 | 
					 | 
				
			||||||
            password = urllib.parse.quote(password, safe=DELIMS)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    return WebSocketURI(secure, host, port, path, query, username, password)
 | 
					 | 
				
			||||||
@@ -1,51 +0,0 @@
 | 
				
			|||||||
from __future__ import annotations
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
import base64
 | 
					 | 
				
			||||||
import hashlib
 | 
					 | 
				
			||||||
import secrets
 | 
					 | 
				
			||||||
import sys
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
__all__ = ["accept_key", "apply_mask"]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def generate_key() -> str:
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Generate a random key for the Sec-WebSocket-Key header.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    key = secrets.token_bytes(16)
 | 
					 | 
				
			||||||
    return base64.b64encode(key).decode()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def accept_key(key: str) -> str:
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Compute the value of the Sec-WebSocket-Accept header.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Args:
 | 
					 | 
				
			||||||
        key: Value of the Sec-WebSocket-Key header.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    sha1 = hashlib.sha1((key + GUID).encode()).digest()
 | 
					 | 
				
			||||||
    return base64.b64encode(sha1).decode()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def apply_mask(data: bytes, mask: bytes) -> bytes:
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Apply masking to the data of a WebSocket message.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Args:
 | 
					 | 
				
			||||||
        data: Data to mask.
 | 
					 | 
				
			||||||
        mask: 4-bytes mask.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    if len(mask) != 4:
 | 
					 | 
				
			||||||
        raise ValueError("mask must contain 4 bytes")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    data_int = int.from_bytes(data, sys.byteorder)
 | 
					 | 
				
			||||||
    mask_repeated = mask * (len(data) // 4) + mask[: len(data) % 4]
 | 
					 | 
				
			||||||
    mask_int = int.from_bytes(mask_repeated, sys.byteorder)
 | 
					 | 
				
			||||||
    return (data_int ^ mask_int).to_bytes(len(data), sys.byteorder)
 | 
					 | 
				
			||||||
@@ -1,92 +0,0 @@
 | 
				
			|||||||
from __future__ import annotations
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
import importlib.metadata
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
__all__ = ["tag", "version", "commit"]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
# ========= =========== ===================
 | 
					 | 
				
			||||||
#           release     development
 | 
					 | 
				
			||||||
# ========= =========== ===================
 | 
					 | 
				
			||||||
# tag       X.Y         X.Y (upcoming)
 | 
					 | 
				
			||||||
# version   X.Y         X.Y.dev1+g5678cde
 | 
					 | 
				
			||||||
# commit    X.Y         5678cde
 | 
					 | 
				
			||||||
# ========= =========== ===================
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
# When tagging a release, set `released = True`.
 | 
					 | 
				
			||||||
# After tagging a release, set `released = False` and increment `tag`.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
released = False
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
tag = version = commit = "13.1"
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
if not released:  # pragma: no cover
 | 
					 | 
				
			||||||
    import pathlib
 | 
					 | 
				
			||||||
    import re
 | 
					 | 
				
			||||||
    import subprocess
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def get_version(tag: str) -> str:
 | 
					 | 
				
			||||||
        # Since setup.py executes the contents of src/websockets/version.py,
 | 
					 | 
				
			||||||
        # __file__ can point to either of these two files.
 | 
					 | 
				
			||||||
        file_path = pathlib.Path(__file__)
 | 
					 | 
				
			||||||
        root_dir = file_path.parents[0 if file_path.name == "setup.py" else 2]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # Read version from package metadata if it is installed.
 | 
					 | 
				
			||||||
        try:
 | 
					 | 
				
			||||||
            version = importlib.metadata.version("websockets")
 | 
					 | 
				
			||||||
        except ImportError:
 | 
					 | 
				
			||||||
            pass
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            # Check that this file belongs to the installed package.
 | 
					 | 
				
			||||||
            files = importlib.metadata.files("websockets")
 | 
					 | 
				
			||||||
            if files:
 | 
					 | 
				
			||||||
                version_files = [f for f in files if f.name == file_path.name]
 | 
					 | 
				
			||||||
                if version_files:
 | 
					 | 
				
			||||||
                    version_file = version_files[0]
 | 
					 | 
				
			||||||
                    if version_file.locate() == file_path:
 | 
					 | 
				
			||||||
                        return version
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # Read version from git if available.
 | 
					 | 
				
			||||||
        try:
 | 
					 | 
				
			||||||
            description = subprocess.run(
 | 
					 | 
				
			||||||
                ["git", "describe", "--dirty", "--tags", "--long"],
 | 
					 | 
				
			||||||
                capture_output=True,
 | 
					 | 
				
			||||||
                cwd=root_dir,
 | 
					 | 
				
			||||||
                timeout=1,
 | 
					 | 
				
			||||||
                check=True,
 | 
					 | 
				
			||||||
                text=True,
 | 
					 | 
				
			||||||
            ).stdout.strip()
 | 
					 | 
				
			||||||
        # subprocess.run raises FileNotFoundError if git isn't on $PATH.
 | 
					 | 
				
			||||||
        except (
 | 
					 | 
				
			||||||
            FileNotFoundError,
 | 
					 | 
				
			||||||
            subprocess.CalledProcessError,
 | 
					 | 
				
			||||||
            subprocess.TimeoutExpired,
 | 
					 | 
				
			||||||
        ):
 | 
					 | 
				
			||||||
            pass
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            description_re = r"[0-9.]+-([0-9]+)-(g[0-9a-f]{7,}(?:-dirty)?)"
 | 
					 | 
				
			||||||
            match = re.fullmatch(description_re, description)
 | 
					 | 
				
			||||||
            if match is None:
 | 
					 | 
				
			||||||
                raise ValueError(f"Unexpected git description: {description}")
 | 
					 | 
				
			||||||
            distance, remainder = match.groups()
 | 
					 | 
				
			||||||
            remainder = remainder.replace("-", ".")  # required by PEP 440
 | 
					 | 
				
			||||||
            return f"{tag}.dev{distance}+{remainder}"
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # Avoid crashing if the development version cannot be determined.
 | 
					 | 
				
			||||||
        return f"{tag}.dev0+gunknown"
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    version = get_version(tag)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def get_commit(tag: str, version: str) -> str:
 | 
					 | 
				
			||||||
        # Extract commit from version, falling back to tag if not available.
 | 
					 | 
				
			||||||
        version_re = r"[0-9.]+\.dev[0-9]+\+g([0-9a-f]{7,}|unknown)(?:\.dirty)?"
 | 
					 | 
				
			||||||
        match = re.fullmatch(version_re, version)
 | 
					 | 
				
			||||||
        if match is None:
 | 
					 | 
				
			||||||
            raise ValueError(f"Unexpected version: {version}")
 | 
					 | 
				
			||||||
        (commit,) = match.groups()
 | 
					 | 
				
			||||||
        return tag if commit == "unknown" else commit
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    commit = get_commit(tag, version)
 | 
					 | 
				
			||||||
		Reference in New Issue
	
	Block a user