generated from itdominator/Python-With-Gtk-Template
Reworked lsp manager plugin; removed websockets library
This commit is contained in:
parent
3edb89ad5c
commit
b31278e114
168
plugins/gtksourceview/lsp_manager/client_ipc.py
Normal file
168
plugins/gtksourceview/lsp_manager/client_ipc.py
Normal file
@ -0,0 +1,168 @@
|
||||
# Python imports
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
import json
|
||||
import base64
|
||||
from multiprocessing.connection import Client
|
||||
from multiprocessing.connection import Listener
|
||||
|
||||
# Lib imports
|
||||
|
||||
# Application imports
|
||||
from .lsp_message_structs import LSPResponseRequest, LSPResponseNotification, get_message_obj
|
||||
|
||||
|
||||
|
||||
class ClientIPC:
|
||||
""" Create a Messenger so talk to LSP Manager. """
|
||||
def __init__(self, ipc_address: str = '127.0.0.1', conn_type: str = "socket"):
|
||||
self.is_ipc_alive = False
|
||||
self._ipc_port = 4848
|
||||
self._ipc_address = ipc_address
|
||||
self._conn_type = conn_type
|
||||
self._ipc_authkey = b'' + bytes(f'lsp-client-endpoint-ipc', 'utf-8')
|
||||
self._manager_ipc_authkey = b'' + bytes(f'lsp-manager-endpoint-ipc', 'utf-8')
|
||||
self._ipc_timeout = 15.0
|
||||
|
||||
if conn_type == "socket":
|
||||
self._ipc_address = f'/tmp/lsp-client-endpoint-ipc.sock'
|
||||
self._manager_ipc_address = f'/tmp/lsp-manager-endpoint-ipc.sock'
|
||||
elif conn_type == "full_network":
|
||||
self._ipc_address = '0.0.0.0'
|
||||
elif conn_type == "full_network_unsecured":
|
||||
self._ipc_authkey = None
|
||||
self._ipc_address = '0.0.0.0'
|
||||
elif conn_type == "local_network_unsecured":
|
||||
self._ipc_authkey = None
|
||||
|
||||
|
||||
def create_ipc_listener(self) -> None:
|
||||
if self._conn_type == "socket":
|
||||
if os.path.exists(self._ipc_address) and settings_manager.is_dirty_start():
|
||||
os.unlink(self._ipc_address)
|
||||
|
||||
listener = Listener(address=self._ipc_address, family="AF_UNIX", authkey=self._ipc_authkey)
|
||||
elif "unsecured" not in self._conn_type:
|
||||
listener = Listener((self._ipc_address, self._ipc_port), authkey=self._ipc_authkey)
|
||||
else:
|
||||
listener = Listener((self._ipc_address, self._ipc_port))
|
||||
|
||||
|
||||
self.is_ipc_alive = True
|
||||
self._run_ipc_loop(listener)
|
||||
|
||||
@daemon_threaded
|
||||
def _run_ipc_loop(self, listener) -> None:
|
||||
# NOTE: Not thread safe if using with Gtk. Need to import GLib and use idle_add
|
||||
while self.is_ipc_alive:
|
||||
try:
|
||||
conn = listener.accept()
|
||||
start_time = time.perf_counter()
|
||||
self._handle_ipc_message(conn, start_time)
|
||||
except Exception as e:
|
||||
logger.debug( repr(e) )
|
||||
|
||||
listener.close()
|
||||
|
||||
def _handle_ipc_message(self, conn, start_time) -> None:
|
||||
while self.is_ipc_alive:
|
||||
msg = conn.recv()
|
||||
logger.debug(msg)
|
||||
|
||||
if "MANAGER|" in msg:
|
||||
data = msg.split("MANAGER|")[1].strip()
|
||||
if data:
|
||||
data_str = base64.b64decode(data.encode("utf-8")).decode("utf-8")
|
||||
lsp_response = None
|
||||
keys = None
|
||||
|
||||
try:
|
||||
lsp_response = json.loads(data_str)
|
||||
keys = lsp_response.keys()
|
||||
except Exception as e:
|
||||
logger.debug( repr(e) )
|
||||
break
|
||||
|
||||
if "result" in keys:
|
||||
lsp_response = LSPResponseRequest(**get_message_obj(data))
|
||||
|
||||
if "method" in keys:
|
||||
lsp_response = LSPResponseNotification(**get_message_obj(data))
|
||||
|
||||
if "notification" in keys:
|
||||
...
|
||||
|
||||
if "response" in keys:
|
||||
...
|
||||
|
||||
if "ignorable" in keys:
|
||||
...
|
||||
|
||||
if lsp_response:
|
||||
self._event_system.emit("handle-lsp-message"), (lsp_response)
|
||||
|
||||
conn.close()
|
||||
break
|
||||
|
||||
if msg in ['close connection', 'close server']:
|
||||
conn.close()
|
||||
break
|
||||
|
||||
# NOTE: Not perfect but insures we don't lock up the connection for too long.
|
||||
end_time = time.perf_counter()
|
||||
if (end_time - start_time) > self._ipc_timeout:
|
||||
conn.close()
|
||||
break
|
||||
|
||||
|
||||
def send_manager_ipc_message(self, message: str) -> None:
|
||||
try:
|
||||
if self._conn_type == "socket":
|
||||
conn = Client(address=self._manager_ipc_address, family="AF_UNIX", authkey=self._manager_ipc_authkey)
|
||||
elif "unsecured" not in self._conn_type:
|
||||
conn = Client((self._ipc_address, self._ipc_port), authkey=self._ipc_authkey)
|
||||
else:
|
||||
conn = Client((self._ipc_address, self._ipc_port))
|
||||
|
||||
conn.send( f"CLIENT|{ base64.b64encode(message.encode("utf-8")).decode("utf-8") }" )
|
||||
conn.close()
|
||||
except ConnectionRefusedError as e:
|
||||
logger.error("Connection refused...")
|
||||
except Exception as e:
|
||||
logger.error( repr(e) )
|
||||
|
||||
|
||||
def send_ipc_message(self, message: str = "Empty Data...") -> None:
|
||||
try:
|
||||
if self._conn_type == "socket":
|
||||
conn = Client(address=self._ipc_address, family="AF_UNIX", authkey=self._ipc_authkey)
|
||||
elif "unsecured" not in self._conn_type:
|
||||
conn = Client((self._ipc_address, self._ipc_port), authkey=self._ipc_authkey)
|
||||
else:
|
||||
conn = Client((self._ipc_address, self._ipc_port))
|
||||
|
||||
conn.send(message)
|
||||
conn.close()
|
||||
except ConnectionRefusedError as e:
|
||||
logger.error("Connection refused...")
|
||||
except Exception as e:
|
||||
logger.error( repr(e) )
|
||||
|
||||
def send_test_ipc_message(self, message: str = "Empty Data...") -> None:
|
||||
try:
|
||||
if self._conn_type == "socket":
|
||||
conn = Client(address=self._ipc_address, family="AF_UNIX", authkey=self._ipc_authkey)
|
||||
elif "unsecured" not in self._conn_type:
|
||||
conn = Client((self._ipc_address, self._ipc_port), authkey=self._ipc_authkey)
|
||||
else:
|
||||
conn = Client((self._ipc_address, self._ipc_port))
|
||||
|
||||
conn.send(message)
|
||||
conn.close()
|
||||
except ConnectionRefusedError as e:
|
||||
if self._conn_type == "socket":
|
||||
logger.error("LSP Socket no longer valid.... Removing.")
|
||||
os.unlink(self._ipc_address)
|
||||
except Exception as e:
|
||||
logger.error( repr(e) )
|
41
plugins/gtksourceview/lsp_manager/lsp_message_structs.py
Normal file
41
plugins/gtksourceview/lsp_manager/lsp_message_structs.py
Normal file
@ -0,0 +1,41 @@
|
||||
# Python imports
|
||||
from dataclasses import dataclass
|
||||
import json
|
||||
|
||||
# Lib imports
|
||||
|
||||
# Application imports
|
||||
|
||||
|
||||
|
||||
def get_message_obj(data: str):
|
||||
return json.loads(data)
|
||||
|
||||
|
||||
@dataclass
|
||||
class LSPResponseRequest(object):
|
||||
"""
|
||||
Constructs a new LSP Response Request instance.
|
||||
|
||||
:param id result: The id of the given message.
|
||||
:param dict result: The arguments of the given method.
|
||||
"""
|
||||
jsonrpc: str
|
||||
id: int
|
||||
result: dict
|
||||
|
||||
@dataclass
|
||||
class LSPResponseNotification(object):
|
||||
"""
|
||||
Constructs a new LSP Response Notification instance.
|
||||
|
||||
:param str method: The type of lsp notification being made.
|
||||
:params dict result: The arguments of the given method.
|
||||
"""
|
||||
jsonrpc: str
|
||||
method: str
|
||||
params: dict
|
||||
|
||||
|
||||
class LSPResponseTypes(LSPResponseRequest, LSPResponseNotification):
|
||||
...
|
@ -1,4 +1,5 @@
|
||||
# Python imports
|
||||
import signal
|
||||
import subprocess
|
||||
import json
|
||||
|
||||
@ -9,7 +10,7 @@ from gi.repository import Gtk
|
||||
|
||||
# Application imports
|
||||
from plugins.plugin_base import PluginBase
|
||||
from .websockets.sync.client import connect
|
||||
from .client_ipc import ClientIPC
|
||||
|
||||
|
||||
|
||||
@ -90,7 +91,27 @@ class Plugin(PluginBase):
|
||||
|
||||
def start_lsp_manager(self, button):
|
||||
if self.lsp_manager_proc: return
|
||||
self.lsp_manager_proc = subprocess.Popen(["lsp-manager"])
|
||||
self.lsp_manager_proc = subprocess.Popen(["python", "/opt/lsp-manager.zip"])
|
||||
# self.lsp_manager_proc = subprocess.Popen(["lsp-manager"])
|
||||
self._load_client_ipc_server()
|
||||
|
||||
def _load_client_ipc_server(self):
|
||||
self.client_ipc = ClientIPC()
|
||||
self._ipc_realization_check(self.client_ipc)
|
||||
|
||||
if not self.client_ipc.is_ipc_alive:
|
||||
raise AppLaunchException(f"LSP IPC Server Already Exists...")
|
||||
|
||||
def _ipc_realization_check(self, ipc_server):
|
||||
try:
|
||||
ipc_server.create_ipc_listener()
|
||||
except Exception:
|
||||
ipc_server.send_test_ipc_message()
|
||||
|
||||
try:
|
||||
ipc_server.create_ipc_listener()
|
||||
except Exception as e:
|
||||
...
|
||||
|
||||
def stop_lsp_manager(self, button = None):
|
||||
if not self.lsp_manager_proc: return
|
||||
@ -99,6 +120,7 @@ class Plugin(PluginBase):
|
||||
return
|
||||
|
||||
self.lsp_manager_proc.terminate()
|
||||
self.client_ipc.is_ipc_alive = False
|
||||
self.lsp_manager_proc = None
|
||||
|
||||
def _lsp_did_open(self, language_id, uri, text):
|
||||
@ -122,11 +144,12 @@ class Plugin(PluginBase):
|
||||
def _lsp_did_close(self):
|
||||
if not self.lsp_manager_proc: return
|
||||
|
||||
def _lsp_did_change(self, language_id, buffer):
|
||||
def _lsp_did_change(self, language_id, uri, buffer):
|
||||
if not self.lsp_manager_proc: return
|
||||
|
||||
iter = buffer.get_iter_at_mark( buffer.get_insert() )
|
||||
line = iter.get_line()
|
||||
column = iter.get_line_offset()
|
||||
start = iter.copy()
|
||||
end = iter.copy()
|
||||
|
||||
@ -139,9 +162,9 @@ class Plugin(PluginBase):
|
||||
"method": "textDocument/didChange",
|
||||
"language_id": language_id,
|
||||
"uri": uri,
|
||||
"text": text
|
||||
"line": -1,
|
||||
"column": -1,
|
||||
"text": text,
|
||||
"line": line,
|
||||
"column": column,
|
||||
"char": ""
|
||||
}
|
||||
|
||||
@ -190,11 +213,5 @@ class Plugin(PluginBase):
|
||||
|
||||
self.send_message(data)
|
||||
|
||||
|
||||
def send_message(self, data: dict):
|
||||
with connect(f"ws://{ self.ws_config['host'] }:{ self.ws_config['port'] }") as websocket:
|
||||
websocket.send(
|
||||
json.dumps(data)
|
||||
)
|
||||
message = websocket.recv()
|
||||
print(f"Received: {message}")
|
||||
self.client_ipc.send_manager_ipc_message( json.dumps(data) )
|
@ -1,199 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
|
||||
from .imports import lazy_import
|
||||
from .version import version as __version__ # noqa: F401
|
||||
|
||||
|
||||
__all__ = [
|
||||
# .client
|
||||
"ClientProtocol",
|
||||
# .datastructures
|
||||
"Headers",
|
||||
"HeadersLike",
|
||||
"MultipleValuesError",
|
||||
# .exceptions
|
||||
"ConcurrencyError",
|
||||
"ConnectionClosed",
|
||||
"ConnectionClosedError",
|
||||
"ConnectionClosedOK",
|
||||
"DuplicateParameter",
|
||||
"InvalidHandshake",
|
||||
"InvalidHeader",
|
||||
"InvalidHeaderFormat",
|
||||
"InvalidHeaderValue",
|
||||
"InvalidOrigin",
|
||||
"InvalidParameterName",
|
||||
"InvalidParameterValue",
|
||||
"InvalidState",
|
||||
"InvalidStatus",
|
||||
"InvalidUpgrade",
|
||||
"InvalidURI",
|
||||
"NegotiationError",
|
||||
"PayloadTooBig",
|
||||
"ProtocolError",
|
||||
"SecurityError",
|
||||
"WebSocketException",
|
||||
"WebSocketProtocolError",
|
||||
# .legacy.auth
|
||||
"BasicAuthWebSocketServerProtocol",
|
||||
"basic_auth_protocol_factory",
|
||||
# .legacy.client
|
||||
"WebSocketClientProtocol",
|
||||
"connect",
|
||||
"unix_connect",
|
||||
# .legacy.exceptions
|
||||
"AbortHandshake",
|
||||
"InvalidMessage",
|
||||
"InvalidStatusCode",
|
||||
"RedirectHandshake",
|
||||
# .legacy.protocol
|
||||
"WebSocketCommonProtocol",
|
||||
# .legacy.server
|
||||
"WebSocketServer",
|
||||
"WebSocketServerProtocol",
|
||||
"broadcast",
|
||||
"serve",
|
||||
"unix_serve",
|
||||
# .server
|
||||
"ServerProtocol",
|
||||
# .typing
|
||||
"Data",
|
||||
"ExtensionName",
|
||||
"ExtensionParameter",
|
||||
"LoggerLike",
|
||||
"StatusLike",
|
||||
"Origin",
|
||||
"Subprotocol",
|
||||
]
|
||||
|
||||
# When type checking, import non-deprecated aliases eagerly. Else, import on demand.
|
||||
if typing.TYPE_CHECKING:
|
||||
from .client import ClientProtocol
|
||||
from .datastructures import Headers, HeadersLike, MultipleValuesError
|
||||
from .exceptions import (
|
||||
ConcurrencyError,
|
||||
ConnectionClosed,
|
||||
ConnectionClosedError,
|
||||
ConnectionClosedOK,
|
||||
DuplicateParameter,
|
||||
InvalidHandshake,
|
||||
InvalidHeader,
|
||||
InvalidHeaderFormat,
|
||||
InvalidHeaderValue,
|
||||
InvalidOrigin,
|
||||
InvalidParameterName,
|
||||
InvalidParameterValue,
|
||||
InvalidState,
|
||||
InvalidStatus,
|
||||
InvalidUpgrade,
|
||||
InvalidURI,
|
||||
NegotiationError,
|
||||
PayloadTooBig,
|
||||
ProtocolError,
|
||||
SecurityError,
|
||||
WebSocketException,
|
||||
WebSocketProtocolError,
|
||||
)
|
||||
from .legacy.auth import (
|
||||
BasicAuthWebSocketServerProtocol,
|
||||
basic_auth_protocol_factory,
|
||||
)
|
||||
from .legacy.client import WebSocketClientProtocol, connect, unix_connect
|
||||
from .legacy.exceptions import (
|
||||
AbortHandshake,
|
||||
InvalidMessage,
|
||||
InvalidStatusCode,
|
||||
RedirectHandshake,
|
||||
)
|
||||
from .legacy.protocol import WebSocketCommonProtocol
|
||||
from .legacy.server import (
|
||||
WebSocketServer,
|
||||
WebSocketServerProtocol,
|
||||
broadcast,
|
||||
serve,
|
||||
unix_serve,
|
||||
)
|
||||
from .server import ServerProtocol
|
||||
from .typing import (
|
||||
Data,
|
||||
ExtensionName,
|
||||
ExtensionParameter,
|
||||
LoggerLike,
|
||||
Origin,
|
||||
StatusLike,
|
||||
Subprotocol,
|
||||
)
|
||||
else:
|
||||
lazy_import(
|
||||
globals(),
|
||||
aliases={
|
||||
# .client
|
||||
"ClientProtocol": ".client",
|
||||
# .datastructures
|
||||
"Headers": ".datastructures",
|
||||
"HeadersLike": ".datastructures",
|
||||
"MultipleValuesError": ".datastructures",
|
||||
# .exceptions
|
||||
"ConcurrencyError": ".exceptions",
|
||||
"ConnectionClosed": ".exceptions",
|
||||
"ConnectionClosedError": ".exceptions",
|
||||
"ConnectionClosedOK": ".exceptions",
|
||||
"DuplicateParameter": ".exceptions",
|
||||
"InvalidHandshake": ".exceptions",
|
||||
"InvalidHeader": ".exceptions",
|
||||
"InvalidHeaderFormat": ".exceptions",
|
||||
"InvalidHeaderValue": ".exceptions",
|
||||
"InvalidOrigin": ".exceptions",
|
||||
"InvalidParameterName": ".exceptions",
|
||||
"InvalidParameterValue": ".exceptions",
|
||||
"InvalidState": ".exceptions",
|
||||
"InvalidStatus": ".exceptions",
|
||||
"InvalidUpgrade": ".exceptions",
|
||||
"InvalidURI": ".exceptions",
|
||||
"NegotiationError": ".exceptions",
|
||||
"PayloadTooBig": ".exceptions",
|
||||
"ProtocolError": ".exceptions",
|
||||
"SecurityError": ".exceptions",
|
||||
"WebSocketException": ".exceptions",
|
||||
"WebSocketProtocolError": ".exceptions",
|
||||
# .legacy.auth
|
||||
"BasicAuthWebSocketServerProtocol": ".legacy.auth",
|
||||
"basic_auth_protocol_factory": ".legacy.auth",
|
||||
# .legacy.client
|
||||
"WebSocketClientProtocol": ".legacy.client",
|
||||
"connect": ".legacy.client",
|
||||
"unix_connect": ".legacy.client",
|
||||
# .legacy.exceptions
|
||||
"AbortHandshake": ".legacy.exceptions",
|
||||
"InvalidMessage": ".legacy.exceptions",
|
||||
"InvalidStatusCode": ".legacy.exceptions",
|
||||
"RedirectHandshake": ".legacy.exceptions",
|
||||
# .legacy.protocol
|
||||
"WebSocketCommonProtocol": ".legacy.protocol",
|
||||
# .legacy.server
|
||||
"WebSocketServer": ".legacy.server",
|
||||
"WebSocketServerProtocol": ".legacy.server",
|
||||
"broadcast": ".legacy.server",
|
||||
"serve": ".legacy.server",
|
||||
"unix_serve": ".legacy.server",
|
||||
# .server
|
||||
"ServerProtocol": ".server",
|
||||
# .typing
|
||||
"Data": ".typing",
|
||||
"ExtensionName": ".typing",
|
||||
"ExtensionParameter": ".typing",
|
||||
"LoggerLike": ".typing",
|
||||
"Origin": ".typing",
|
||||
"StatusLike": ".typing",
|
||||
"Subprotocol": ".typing",
|
||||
},
|
||||
deprecated_aliases={
|
||||
# deprecated in 9.0 - 2021-09-01
|
||||
"framing": ".legacy",
|
||||
"handshake": ".legacy",
|
||||
"parse_uri": ".uri",
|
||||
"WebSocketURI": ".uri",
|
||||
},
|
||||
)
|
@ -1,159 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
import threading
|
||||
|
||||
|
||||
try:
|
||||
import readline # noqa: F401
|
||||
except ImportError: # Windows has no `readline` normally
|
||||
pass
|
||||
|
||||
from .sync.client import ClientConnection, connect
|
||||
from .version import version as websockets_version
|
||||
|
||||
|
||||
if sys.platform == "win32":
|
||||
|
||||
def win_enable_vt100() -> None:
|
||||
"""
|
||||
Enable VT-100 for console output on Windows.
|
||||
|
||||
See also https://github.com/python/cpython/issues/73245.
|
||||
|
||||
"""
|
||||
import ctypes
|
||||
|
||||
STD_OUTPUT_HANDLE = ctypes.c_uint(-11)
|
||||
INVALID_HANDLE_VALUE = ctypes.c_uint(-1)
|
||||
ENABLE_VIRTUAL_TERMINAL_PROCESSING = 0x004
|
||||
|
||||
handle = ctypes.windll.kernel32.GetStdHandle(STD_OUTPUT_HANDLE)
|
||||
if handle == INVALID_HANDLE_VALUE:
|
||||
raise RuntimeError("unable to obtain stdout handle")
|
||||
|
||||
cur_mode = ctypes.c_uint()
|
||||
if ctypes.windll.kernel32.GetConsoleMode(handle, ctypes.byref(cur_mode)) == 0:
|
||||
raise RuntimeError("unable to query current console mode")
|
||||
|
||||
# ctypes ints lack support for the required bit-OR operation.
|
||||
# Temporarily convert to Py int, do the OR and convert back.
|
||||
py_int_mode = int.from_bytes(cur_mode, sys.byteorder)
|
||||
new_mode = ctypes.c_uint(py_int_mode | ENABLE_VIRTUAL_TERMINAL_PROCESSING)
|
||||
|
||||
if ctypes.windll.kernel32.SetConsoleMode(handle, new_mode) == 0:
|
||||
raise RuntimeError("unable to set console mode")
|
||||
|
||||
|
||||
def print_during_input(string: str) -> None:
|
||||
sys.stdout.write(
|
||||
# Save cursor position
|
||||
"\N{ESC}7"
|
||||
# Add a new line
|
||||
"\N{LINE FEED}"
|
||||
# Move cursor up
|
||||
"\N{ESC}[A"
|
||||
# Insert blank line, scroll last line down
|
||||
"\N{ESC}[L"
|
||||
# Print string in the inserted blank line
|
||||
f"{string}\N{LINE FEED}"
|
||||
# Restore cursor position
|
||||
"\N{ESC}8"
|
||||
# Move cursor down
|
||||
"\N{ESC}[B"
|
||||
)
|
||||
sys.stdout.flush()
|
||||
|
||||
|
||||
def print_over_input(string: str) -> None:
|
||||
sys.stdout.write(
|
||||
# Move cursor to beginning of line
|
||||
"\N{CARRIAGE RETURN}"
|
||||
# Delete current line
|
||||
"\N{ESC}[K"
|
||||
# Print string
|
||||
f"{string}\N{LINE FEED}"
|
||||
)
|
||||
sys.stdout.flush()
|
||||
|
||||
|
||||
def print_incoming_messages(websocket: ClientConnection, stop: threading.Event) -> None:
|
||||
for message in websocket:
|
||||
if isinstance(message, str):
|
||||
print_during_input("< " + message)
|
||||
else:
|
||||
print_during_input("< (binary) " + message.hex())
|
||||
if not stop.is_set():
|
||||
# When the server closes the connection, raise KeyboardInterrupt
|
||||
# in the main thread to exit the program.
|
||||
if sys.platform == "win32":
|
||||
ctrl_c = signal.CTRL_C_EVENT
|
||||
else:
|
||||
ctrl_c = signal.SIGINT
|
||||
os.kill(os.getpid(), ctrl_c)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
# Parse command line arguments.
|
||||
parser = argparse.ArgumentParser(
|
||||
prog="python -m websockets",
|
||||
description="Interactive WebSocket client.",
|
||||
add_help=False,
|
||||
)
|
||||
group = parser.add_mutually_exclusive_group()
|
||||
group.add_argument("--version", action="store_true")
|
||||
group.add_argument("uri", metavar="<uri>", nargs="?")
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.version:
|
||||
print(f"websockets {websockets_version}")
|
||||
return
|
||||
|
||||
if args.uri is None:
|
||||
parser.error("the following arguments are required: <uri>")
|
||||
|
||||
# If we're on Windows, enable VT100 terminal support.
|
||||
if sys.platform == "win32":
|
||||
try:
|
||||
win_enable_vt100()
|
||||
except RuntimeError as exc:
|
||||
sys.stderr.write(
|
||||
f"Unable to set terminal to VT100 mode. This is only "
|
||||
f"supported since Win10 anniversary update. Expect "
|
||||
f"weird symbols on the terminal.\nError: {exc}\n"
|
||||
)
|
||||
sys.stderr.flush()
|
||||
|
||||
try:
|
||||
websocket = connect(args.uri)
|
||||
except Exception as exc:
|
||||
print(f"Failed to connect to {args.uri}: {exc}.")
|
||||
sys.exit(1)
|
||||
else:
|
||||
print(f"Connected to {args.uri}.")
|
||||
|
||||
stop = threading.Event()
|
||||
|
||||
# Start the thread that reads messages from the connection.
|
||||
thread = threading.Thread(target=print_incoming_messages, args=(websocket, stop))
|
||||
thread.start()
|
||||
|
||||
# Read from stdin in the main thread in order to receive signals.
|
||||
try:
|
||||
while True:
|
||||
# Since there's no size limit, put_nowait is identical to put.
|
||||
message = input("> ")
|
||||
websocket.send(message)
|
||||
except (KeyboardInterrupt, EOFError): # ^C, ^D
|
||||
stop.set()
|
||||
websocket.close()
|
||||
print_over_input("Connection closed.")
|
||||
|
||||
thread.join()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -1,282 +0,0 @@
|
||||
# From https://github.com/aio-libs/async-timeout/blob/master/async_timeout/__init__.py
|
||||
# Licensed under the Apache License (Apache-2.0)
|
||||
|
||||
import asyncio
|
||||
import enum
|
||||
import sys
|
||||
import warnings
|
||||
from types import TracebackType
|
||||
from typing import Optional, Type
|
||||
|
||||
|
||||
if sys.version_info >= (3, 11):
|
||||
from typing import final
|
||||
else:
|
||||
# From https://github.com/python/typing_extensions/blob/main/src/typing_extensions.py
|
||||
# Licensed under the Python Software Foundation License (PSF-2.0)
|
||||
|
||||
# @final exists in 3.8+, but we backport it for all versions
|
||||
# before 3.11 to keep support for the __final__ attribute.
|
||||
# See https://bugs.python.org/issue46342
|
||||
def final(f):
|
||||
"""This decorator can be used to indicate to type checkers that
|
||||
the decorated method cannot be overridden, and decorated class
|
||||
cannot be subclassed. For example:
|
||||
|
||||
class Base:
|
||||
@final
|
||||
def done(self) -> None:
|
||||
...
|
||||
class Sub(Base):
|
||||
def done(self) -> None: # Error reported by type checker
|
||||
...
|
||||
@final
|
||||
class Leaf:
|
||||
...
|
||||
class Other(Leaf): # Error reported by type checker
|
||||
...
|
||||
|
||||
There is no runtime checking of these properties. The decorator
|
||||
sets the ``__final__`` attribute to ``True`` on the decorated object
|
||||
to allow runtime introspection.
|
||||
"""
|
||||
try:
|
||||
f.__final__ = True
|
||||
except (AttributeError, TypeError):
|
||||
# Skip the attribute silently if it is not writable.
|
||||
# AttributeError happens if the object has __slots__ or a
|
||||
# read-only property, TypeError if it's a builtin class.
|
||||
pass
|
||||
return f
|
||||
|
||||
# End https://github.com/python/typing_extensions/blob/main/src/typing_extensions.py
|
||||
|
||||
|
||||
if sys.version_info >= (3, 11):
|
||||
|
||||
def _uncancel_task(task: "asyncio.Task[object]") -> None:
|
||||
task.uncancel()
|
||||
|
||||
else:
|
||||
|
||||
def _uncancel_task(task: "asyncio.Task[object]") -> None:
|
||||
pass
|
||||
|
||||
|
||||
__version__ = "4.0.3"
|
||||
|
||||
|
||||
__all__ = ("timeout", "timeout_at", "Timeout")
|
||||
|
||||
|
||||
def timeout(delay: Optional[float]) -> "Timeout":
|
||||
"""timeout context manager.
|
||||
|
||||
Useful in cases when you want to apply timeout logic around block
|
||||
of code or in cases when asyncio.wait_for is not suitable. For example:
|
||||
|
||||
>>> async with timeout(0.001):
|
||||
... async with aiohttp.get('https://github.com') as r:
|
||||
... await r.text()
|
||||
|
||||
|
||||
delay - value in seconds or None to disable timeout logic
|
||||
"""
|
||||
loop = asyncio.get_running_loop()
|
||||
if delay is not None:
|
||||
deadline = loop.time() + delay # type: Optional[float]
|
||||
else:
|
||||
deadline = None
|
||||
return Timeout(deadline, loop)
|
||||
|
||||
|
||||
def timeout_at(deadline: Optional[float]) -> "Timeout":
|
||||
"""Schedule the timeout at absolute time.
|
||||
|
||||
deadline argument points on the time in the same clock system
|
||||
as loop.time().
|
||||
|
||||
Please note: it is not POSIX time but a time with
|
||||
undefined starting base, e.g. the time of the system power on.
|
||||
|
||||
>>> async with timeout_at(loop.time() + 10):
|
||||
... async with aiohttp.get('https://github.com') as r:
|
||||
... await r.text()
|
||||
|
||||
|
||||
"""
|
||||
loop = asyncio.get_running_loop()
|
||||
return Timeout(deadline, loop)
|
||||
|
||||
|
||||
class _State(enum.Enum):
|
||||
INIT = "INIT"
|
||||
ENTER = "ENTER"
|
||||
TIMEOUT = "TIMEOUT"
|
||||
EXIT = "EXIT"
|
||||
|
||||
|
||||
@final
|
||||
class Timeout:
|
||||
# Internal class, please don't instantiate it directly
|
||||
# Use timeout() and timeout_at() public factories instead.
|
||||
#
|
||||
# Implementation note: `async with timeout()` is preferred
|
||||
# over `with timeout()`.
|
||||
# While technically the Timeout class implementation
|
||||
# doesn't need to be async at all,
|
||||
# the `async with` statement explicitly points that
|
||||
# the context manager should be used from async function context.
|
||||
#
|
||||
# This design allows to avoid many silly misusages.
|
||||
#
|
||||
# TimeoutError is raised immediately when scheduled
|
||||
# if the deadline is passed.
|
||||
# The purpose is to time out as soon as possible
|
||||
# without waiting for the next await expression.
|
||||
|
||||
__slots__ = ("_deadline", "_loop", "_state", "_timeout_handler", "_task")
|
||||
|
||||
def __init__(
|
||||
self, deadline: Optional[float], loop: asyncio.AbstractEventLoop
|
||||
) -> None:
|
||||
self._loop = loop
|
||||
self._state = _State.INIT
|
||||
|
||||
self._task: Optional["asyncio.Task[object]"] = None
|
||||
self._timeout_handler = None # type: Optional[asyncio.Handle]
|
||||
if deadline is None:
|
||||
self._deadline = None # type: Optional[float]
|
||||
else:
|
||||
self.update(deadline)
|
||||
|
||||
def __enter__(self) -> "Timeout":
|
||||
warnings.warn(
|
||||
"with timeout() is deprecated, use async with timeout() instead",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
self._do_enter()
|
||||
return self
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: Optional[Type[BaseException]],
|
||||
exc_val: Optional[BaseException],
|
||||
exc_tb: Optional[TracebackType],
|
||||
) -> Optional[bool]:
|
||||
self._do_exit(exc_type)
|
||||
return None
|
||||
|
||||
async def __aenter__(self) -> "Timeout":
|
||||
self._do_enter()
|
||||
return self
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: Optional[Type[BaseException]],
|
||||
exc_val: Optional[BaseException],
|
||||
exc_tb: Optional[TracebackType],
|
||||
) -> Optional[bool]:
|
||||
self._do_exit(exc_type)
|
||||
return None
|
||||
|
||||
@property
|
||||
def expired(self) -> bool:
|
||||
"""Is timeout expired during execution?"""
|
||||
return self._state == _State.TIMEOUT
|
||||
|
||||
@property
|
||||
def deadline(self) -> Optional[float]:
|
||||
return self._deadline
|
||||
|
||||
def reject(self) -> None:
|
||||
"""Reject scheduled timeout if any."""
|
||||
# cancel is maybe better name but
|
||||
# task.cancel() raises CancelledError in asyncio world.
|
||||
if self._state not in (_State.INIT, _State.ENTER):
|
||||
raise RuntimeError(f"invalid state {self._state.value}")
|
||||
self._reject()
|
||||
|
||||
def _reject(self) -> None:
|
||||
self._task = None
|
||||
if self._timeout_handler is not None:
|
||||
self._timeout_handler.cancel()
|
||||
self._timeout_handler = None
|
||||
|
||||
def shift(self, delay: float) -> None:
|
||||
"""Advance timeout on delay seconds.
|
||||
|
||||
The delay can be negative.
|
||||
|
||||
Raise RuntimeError if shift is called when deadline is not scheduled
|
||||
"""
|
||||
deadline = self._deadline
|
||||
if deadline is None:
|
||||
raise RuntimeError("cannot shift timeout if deadline is not scheduled")
|
||||
self.update(deadline + delay)
|
||||
|
||||
def update(self, deadline: float) -> None:
|
||||
"""Set deadline to absolute value.
|
||||
|
||||
deadline argument points on the time in the same clock system
|
||||
as loop.time().
|
||||
|
||||
If new deadline is in the past the timeout is raised immediately.
|
||||
|
||||
Please note: it is not POSIX time but a time with
|
||||
undefined starting base, e.g. the time of the system power on.
|
||||
"""
|
||||
if self._state == _State.EXIT:
|
||||
raise RuntimeError("cannot reschedule after exit from context manager")
|
||||
if self._state == _State.TIMEOUT:
|
||||
raise RuntimeError("cannot reschedule expired timeout")
|
||||
if self._timeout_handler is not None:
|
||||
self._timeout_handler.cancel()
|
||||
self._deadline = deadline
|
||||
if self._state != _State.INIT:
|
||||
self._reschedule()
|
||||
|
||||
def _reschedule(self) -> None:
|
||||
assert self._state == _State.ENTER
|
||||
deadline = self._deadline
|
||||
if deadline is None:
|
||||
return
|
||||
|
||||
now = self._loop.time()
|
||||
if self._timeout_handler is not None:
|
||||
self._timeout_handler.cancel()
|
||||
|
||||
self._task = asyncio.current_task()
|
||||
if deadline <= now:
|
||||
self._timeout_handler = self._loop.call_soon(self._on_timeout)
|
||||
else:
|
||||
self._timeout_handler = self._loop.call_at(deadline, self._on_timeout)
|
||||
|
||||
def _do_enter(self) -> None:
|
||||
if self._state != _State.INIT:
|
||||
raise RuntimeError(f"invalid state {self._state.value}")
|
||||
self._state = _State.ENTER
|
||||
self._reschedule()
|
||||
|
||||
def _do_exit(self, exc_type: Optional[Type[BaseException]]) -> None:
|
||||
if exc_type is asyncio.CancelledError and self._state == _State.TIMEOUT:
|
||||
assert self._task is not None
|
||||
_uncancel_task(self._task)
|
||||
self._timeout_handler = None
|
||||
self._task = None
|
||||
raise asyncio.TimeoutError
|
||||
# timeout has not expired
|
||||
self._state = _State.EXIT
|
||||
self._reject()
|
||||
return None
|
||||
|
||||
def _on_timeout(self) -> None:
|
||||
assert self._task is not None
|
||||
self._task.cancel()
|
||||
self._state = _State.TIMEOUT
|
||||
# drop the reference early
|
||||
self._timeout_handler = None
|
||||
|
||||
|
||||
# End https://github.com/aio-libs/async-timeout/blob/master/async_timeout/__init__.py
|
@ -1,561 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import urllib.parse
|
||||
from types import TracebackType
|
||||
from typing import Any, AsyncIterator, Callable, Generator, Sequence
|
||||
|
||||
from ..client import ClientProtocol, backoff
|
||||
from ..datastructures import HeadersLike
|
||||
from ..exceptions import InvalidStatus, SecurityError
|
||||
from ..extensions.base import ClientExtensionFactory
|
||||
from ..extensions.permessage_deflate import enable_client_permessage_deflate
|
||||
from ..headers import validate_subprotocols
|
||||
from ..http11 import USER_AGENT, Response
|
||||
from ..protocol import CONNECTING, Event
|
||||
from ..typing import LoggerLike, Origin, Subprotocol
|
||||
from ..uri import WebSocketURI, parse_uri
|
||||
from .compatibility import TimeoutError, asyncio_timeout
|
||||
from .connection import Connection
|
||||
|
||||
|
||||
__all__ = ["connect", "unix_connect", "ClientConnection"]
|
||||
|
||||
MAX_REDIRECTS = int(os.environ.get("WEBSOCKETS_MAX_REDIRECTS", "10"))
|
||||
|
||||
|
||||
class ClientConnection(Connection):
|
||||
"""
|
||||
:mod:`asyncio` implementation of a WebSocket client connection.
|
||||
|
||||
:class:`ClientConnection` provides :meth:`recv` and :meth:`send` coroutines
|
||||
for receiving and sending messages.
|
||||
|
||||
It supports asynchronous iteration to receive messages::
|
||||
|
||||
async for message in websocket:
|
||||
await process(message)
|
||||
|
||||
The iterator exits normally when the connection is closed with close code
|
||||
1000 (OK) or 1001 (going away) or without a close code. It raises a
|
||||
:exc:`~websockets.exceptions.ConnectionClosedError` when the connection is
|
||||
closed with any other code.
|
||||
|
||||
The ``ping_interval``, ``ping_timeout``, ``close_timeout``, ``max_queue``,
|
||||
and ``write_limit`` arguments the same meaning as in :func:`connect`.
|
||||
|
||||
Args:
|
||||
protocol: Sans-I/O connection.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
protocol: ClientProtocol,
|
||||
*,
|
||||
ping_interval: float | None = 20,
|
||||
ping_timeout: float | None = 20,
|
||||
close_timeout: float | None = 10,
|
||||
max_queue: int | tuple[int, int | None] = 16,
|
||||
write_limit: int | tuple[int, int | None] = 2**15,
|
||||
) -> None:
|
||||
self.protocol: ClientProtocol
|
||||
super().__init__(
|
||||
protocol,
|
||||
ping_interval=ping_interval,
|
||||
ping_timeout=ping_timeout,
|
||||
close_timeout=close_timeout,
|
||||
max_queue=max_queue,
|
||||
write_limit=write_limit,
|
||||
)
|
||||
self.response_rcvd: asyncio.Future[None] = self.loop.create_future()
|
||||
|
||||
async def handshake(
|
||||
self,
|
||||
additional_headers: HeadersLike | None = None,
|
||||
user_agent_header: str | None = USER_AGENT,
|
||||
) -> None:
|
||||
"""
|
||||
Perform the opening handshake.
|
||||
|
||||
"""
|
||||
async with self.send_context(expected_state=CONNECTING):
|
||||
self.request = self.protocol.connect()
|
||||
if additional_headers is not None:
|
||||
self.request.headers.update(additional_headers)
|
||||
if user_agent_header:
|
||||
self.request.headers["User-Agent"] = user_agent_header
|
||||
self.protocol.send_request(self.request)
|
||||
|
||||
await asyncio.wait(
|
||||
[self.response_rcvd, self.connection_lost_waiter],
|
||||
return_when=asyncio.FIRST_COMPLETED,
|
||||
)
|
||||
|
||||
# self.protocol.handshake_exc is always set when the connection is lost
|
||||
# before receiving a response, when the response cannot be parsed, or
|
||||
# when the response fails the handshake.
|
||||
|
||||
if self.protocol.handshake_exc is not None:
|
||||
raise self.protocol.handshake_exc
|
||||
|
||||
def process_event(self, event: Event) -> None:
|
||||
"""
|
||||
Process one incoming event.
|
||||
|
||||
"""
|
||||
# First event - handshake response.
|
||||
if self.response is None:
|
||||
assert isinstance(event, Response)
|
||||
self.response = event
|
||||
self.response_rcvd.set_result(None)
|
||||
# Later events - frames.
|
||||
else:
|
||||
super().process_event(event)
|
||||
|
||||
|
||||
def process_exception(exc: Exception) -> Exception | None:
|
||||
"""
|
||||
Determine whether a connection error is retryable or fatal.
|
||||
|
||||
When reconnecting automatically with ``async for ... in connect(...)``, if a
|
||||
connection attempt fails, :func:`process_exception` is called to determine
|
||||
whether to retry connecting or to raise the exception.
|
||||
|
||||
This function defines the default behavior, which is to retry on:
|
||||
|
||||
* :exc:`EOFError`, :exc:`OSError`, :exc:`asyncio.TimeoutError`: network
|
||||
errors;
|
||||
* :exc:`~websockets.exceptions.InvalidStatus` when the status code is 500,
|
||||
502, 503, or 504: server or proxy errors.
|
||||
|
||||
All other exceptions are considered fatal.
|
||||
|
||||
You can change this behavior with the ``process_exception`` argument of
|
||||
:func:`connect`.
|
||||
|
||||
Return :obj:`None` if the exception is retryable i.e. when the error could
|
||||
be transient and trying to reconnect with the same parameters could succeed.
|
||||
The exception will be logged at the ``INFO`` level.
|
||||
|
||||
Return an exception, either ``exc`` or a new exception, if the exception is
|
||||
fatal i.e. when trying to reconnect will most likely produce the same error.
|
||||
That exception will be raised, breaking out of the retry loop.
|
||||
|
||||
"""
|
||||
if isinstance(exc, (EOFError, OSError, asyncio.TimeoutError)):
|
||||
return None
|
||||
if isinstance(exc, InvalidStatus) and exc.response.status_code in [
|
||||
500, # Internal Server Error
|
||||
502, # Bad Gateway
|
||||
503, # Service Unavailable
|
||||
504, # Gateway Timeout
|
||||
]:
|
||||
return None
|
||||
return exc
|
||||
|
||||
|
||||
# This is spelled in lower case because it's exposed as a callable in the API.
|
||||
class connect:
|
||||
"""
|
||||
Connect to the WebSocket server at ``uri``.
|
||||
|
||||
This coroutine returns a :class:`ClientConnection` instance, which you can
|
||||
use to send and receive messages.
|
||||
|
||||
:func:`connect` may be used as an asynchronous context manager::
|
||||
|
||||
from websockets.asyncio.client import connect
|
||||
|
||||
async with connect(...) as websocket:
|
||||
...
|
||||
|
||||
The connection is closed automatically when exiting the context.
|
||||
|
||||
:func:`connect` can be used as an infinite asynchronous iterator to
|
||||
reconnect automatically on errors::
|
||||
|
||||
async for websocket in connect(...):
|
||||
try:
|
||||
...
|
||||
except websockets.ConnectionClosed:
|
||||
continue
|
||||
|
||||
If the connection fails with a transient error, it is retried with
|
||||
exponential backoff. If it fails with a fatal error, the exception is
|
||||
raised, breaking out of the loop.
|
||||
|
||||
The connection is closed automatically after each iteration of the loop.
|
||||
|
||||
Args:
|
||||
uri: URI of the WebSocket server.
|
||||
origin: Value of the ``Origin`` header, for servers that require it.
|
||||
extensions: List of supported extensions, in order in which they
|
||||
should be negotiated and run.
|
||||
subprotocols: List of supported subprotocols, in order of decreasing
|
||||
preference.
|
||||
additional_headers (HeadersLike | None): Arbitrary HTTP headers to add
|
||||
to the handshake request.
|
||||
user_agent_header: Value of the ``User-Agent`` request header.
|
||||
It defaults to ``"Python/x.y.z websockets/X.Y"``.
|
||||
Setting it to :obj:`None` removes the header.
|
||||
compression: The "permessage-deflate" extension is enabled by default.
|
||||
Set ``compression`` to :obj:`None` to disable it. See the
|
||||
:doc:`compression guide <../../topics/compression>` for details.
|
||||
process_exception: When reconnecting automatically, tell whether an
|
||||
error is transient or fatal. The default behavior is defined by
|
||||
:func:`process_exception`. Refer to its documentation for details.
|
||||
open_timeout: Timeout for opening the connection in seconds.
|
||||
:obj:`None` disables the timeout.
|
||||
ping_interval: Interval between keepalive pings in seconds.
|
||||
:obj:`None` disables keepalive.
|
||||
ping_timeout: Timeout for keepalive pings in seconds.
|
||||
:obj:`None` disables timeouts.
|
||||
close_timeout: Timeout for closing the connection in seconds.
|
||||
:obj:`None` disables the timeout.
|
||||
max_size: Maximum size of incoming messages in bytes.
|
||||
:obj:`None` disables the limit.
|
||||
max_queue: High-water mark of the buffer where frames are received.
|
||||
It defaults to 16 frames. The low-water mark defaults to ``max_queue
|
||||
// 4``. You may pass a ``(high, low)`` tuple to set the high-water
|
||||
and low-water marks.
|
||||
write_limit: High-water mark of write buffer in bytes. It is passed to
|
||||
:meth:`~asyncio.WriteTransport.set_write_buffer_limits`. It defaults
|
||||
to 32 KiB. You may pass a ``(high, low)`` tuple to set the
|
||||
high-water and low-water marks.
|
||||
logger: Logger for this client.
|
||||
It defaults to ``logging.getLogger("websockets.client")``.
|
||||
See the :doc:`logging guide <../../topics/logging>` for details.
|
||||
create_connection: Factory for the :class:`ClientConnection` managing
|
||||
the connection. Set it to a wrapper or a subclass to customize
|
||||
connection handling.
|
||||
|
||||
Any other keyword arguments are passed to the event loop's
|
||||
:meth:`~asyncio.loop.create_connection` method.
|
||||
|
||||
For example:
|
||||
|
||||
* You can set ``ssl`` to a :class:`~ssl.SSLContext` to enforce TLS settings.
|
||||
When connecting to a ``wss://`` URI, if ``ssl`` isn't provided, a TLS
|
||||
context is created with :func:`~ssl.create_default_context`.
|
||||
|
||||
* You can set ``server_hostname`` to override the host name from ``uri`` in
|
||||
the TLS handshake.
|
||||
|
||||
* You can set ``host`` and ``port`` to connect to a different host and port
|
||||
from those found in ``uri``. This only changes the destination of the TCP
|
||||
connection. The host name from ``uri`` is still used in the TLS handshake
|
||||
for secure connections and in the ``Host`` header.
|
||||
|
||||
* You can set ``sock`` to provide a preexisting TCP socket. You may call
|
||||
:func:`socket.create_connection` (not to be confused with the event loop's
|
||||
:meth:`~asyncio.loop.create_connection` method) to create a suitable
|
||||
client socket and customize it.
|
||||
|
||||
Raises:
|
||||
InvalidURI: If ``uri`` isn't a valid WebSocket URI.
|
||||
OSError: If the TCP connection fails.
|
||||
InvalidHandshake: If the opening handshake fails.
|
||||
TimeoutError: If the opening handshake times out.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
uri: str,
|
||||
*,
|
||||
# WebSocket
|
||||
origin: Origin | None = None,
|
||||
extensions: Sequence[ClientExtensionFactory] | None = None,
|
||||
subprotocols: Sequence[Subprotocol] | None = None,
|
||||
additional_headers: HeadersLike | None = None,
|
||||
user_agent_header: str | None = USER_AGENT,
|
||||
compression: str | None = "deflate",
|
||||
process_exception: Callable[[Exception], Exception | None] = process_exception,
|
||||
# Timeouts
|
||||
open_timeout: float | None = 10,
|
||||
ping_interval: float | None = 20,
|
||||
ping_timeout: float | None = 20,
|
||||
close_timeout: float | None = 10,
|
||||
# Limits
|
||||
max_size: int | None = 2**20,
|
||||
max_queue: int | tuple[int, int | None] = 16,
|
||||
write_limit: int | tuple[int, int | None] = 2**15,
|
||||
# Logging
|
||||
logger: LoggerLike | None = None,
|
||||
# Escape hatch for advanced customization
|
||||
create_connection: type[ClientConnection] | None = None,
|
||||
# Other keyword arguments are passed to loop.create_connection
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
self.uri = uri
|
||||
|
||||
if subprotocols is not None:
|
||||
validate_subprotocols(subprotocols)
|
||||
|
||||
if compression == "deflate":
|
||||
extensions = enable_client_permessage_deflate(extensions)
|
||||
elif compression is not None:
|
||||
raise ValueError(f"unsupported compression: {compression}")
|
||||
|
||||
if logger is None:
|
||||
logger = logging.getLogger("websockets.client")
|
||||
|
||||
if create_connection is None:
|
||||
create_connection = ClientConnection
|
||||
|
||||
def protocol_factory(wsuri: WebSocketURI) -> ClientConnection:
|
||||
# This is a protocol in the Sans-I/O implementation of websockets.
|
||||
protocol = ClientProtocol(
|
||||
wsuri,
|
||||
origin=origin,
|
||||
extensions=extensions,
|
||||
subprotocols=subprotocols,
|
||||
max_size=max_size,
|
||||
logger=logger,
|
||||
)
|
||||
# This is a connection in websockets and a protocol in asyncio.
|
||||
connection = create_connection(
|
||||
protocol,
|
||||
ping_interval=ping_interval,
|
||||
ping_timeout=ping_timeout,
|
||||
close_timeout=close_timeout,
|
||||
max_queue=max_queue,
|
||||
write_limit=write_limit,
|
||||
)
|
||||
return connection
|
||||
|
||||
self.protocol_factory = protocol_factory
|
||||
self.handshake_args = (
|
||||
additional_headers,
|
||||
user_agent_header,
|
||||
)
|
||||
self.process_exception = process_exception
|
||||
self.open_timeout = open_timeout
|
||||
self.logger = logger
|
||||
self.connection_kwargs = kwargs
|
||||
|
||||
async def create_connection(self) -> ClientConnection:
|
||||
"""Create TCP or Unix connection."""
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
wsuri = parse_uri(self.uri)
|
||||
kwargs = self.connection_kwargs.copy()
|
||||
|
||||
def factory() -> ClientConnection:
|
||||
return self.protocol_factory(wsuri)
|
||||
|
||||
if wsuri.secure:
|
||||
kwargs.setdefault("ssl", True)
|
||||
kwargs.setdefault("server_hostname", wsuri.host)
|
||||
if kwargs.get("ssl") is None:
|
||||
raise TypeError("ssl=None is incompatible with a wss:// URI")
|
||||
else:
|
||||
if kwargs.get("ssl") is not None:
|
||||
raise TypeError("ssl argument is incompatible with a ws:// URI")
|
||||
|
||||
if kwargs.pop("unix", False):
|
||||
_, connection = await loop.create_unix_connection(factory, **kwargs)
|
||||
else:
|
||||
if kwargs.get("sock") is None:
|
||||
kwargs.setdefault("host", wsuri.host)
|
||||
kwargs.setdefault("port", wsuri.port)
|
||||
_, connection = await loop.create_connection(factory, **kwargs)
|
||||
return connection
|
||||
|
||||
def process_redirect(self, exc: Exception) -> Exception | str:
|
||||
"""
|
||||
Determine whether a connection error is a redirect that can be followed.
|
||||
|
||||
Return the new URI if it's a valid redirect. Else, return an exception.
|
||||
|
||||
"""
|
||||
if not (
|
||||
isinstance(exc, InvalidStatus)
|
||||
and exc.response.status_code
|
||||
in [
|
||||
300, # Multiple Choices
|
||||
301, # Moved Permanently
|
||||
302, # Found
|
||||
303, # See Other
|
||||
307, # Temporary Redirect
|
||||
308, # Permanent Redirect
|
||||
]
|
||||
and "Location" in exc.response.headers
|
||||
):
|
||||
return exc
|
||||
|
||||
old_wsuri = parse_uri(self.uri)
|
||||
new_uri = urllib.parse.urljoin(self.uri, exc.response.headers["Location"])
|
||||
new_wsuri = parse_uri(new_uri)
|
||||
|
||||
# If connect() received a socket, it is closed and cannot be reused.
|
||||
if self.connection_kwargs.get("sock") is not None:
|
||||
return ValueError(
|
||||
f"cannot follow redirect to {new_uri} with a preexisting socket"
|
||||
)
|
||||
|
||||
# TLS downgrade is forbidden.
|
||||
if old_wsuri.secure and not new_wsuri.secure:
|
||||
return SecurityError(f"cannot follow redirect to non-secure URI {new_uri}")
|
||||
|
||||
# Apply restrictions to cross-origin redirects.
|
||||
if (
|
||||
old_wsuri.secure != new_wsuri.secure
|
||||
or old_wsuri.host != new_wsuri.host
|
||||
or old_wsuri.port != new_wsuri.port
|
||||
):
|
||||
# Cross-origin redirects on Unix sockets don't quite make sense.
|
||||
if self.connection_kwargs.get("unix", False):
|
||||
return ValueError(
|
||||
f"cannot follow cross-origin redirect to {new_uri} "
|
||||
f"with a Unix socket"
|
||||
)
|
||||
|
||||
# Cross-origin redirects when host and port are overridden are ill-defined.
|
||||
if (
|
||||
self.connection_kwargs.get("host") is not None
|
||||
or self.connection_kwargs.get("port") is not None
|
||||
):
|
||||
return ValueError(
|
||||
f"cannot follow cross-origin redirect to {new_uri} "
|
||||
f"with an explicit host or port"
|
||||
)
|
||||
|
||||
return new_uri
|
||||
|
||||
# ... = await connect(...)
|
||||
|
||||
def __await__(self) -> Generator[Any, None, ClientConnection]:
|
||||
# Create a suitable iterator by calling __await__ on a coroutine.
|
||||
return self.__await_impl__().__await__()
|
||||
|
||||
async def __await_impl__(self) -> ClientConnection:
|
||||
try:
|
||||
async with asyncio_timeout(self.open_timeout):
|
||||
for _ in range(MAX_REDIRECTS):
|
||||
self.connection = await self.create_connection()
|
||||
try:
|
||||
await self.connection.handshake(*self.handshake_args)
|
||||
except asyncio.CancelledError:
|
||||
self.connection.close_transport()
|
||||
raise
|
||||
except Exception as exc:
|
||||
# Always close the connection even though keep-alive is
|
||||
# the default in HTTP/1.1 because create_connection ties
|
||||
# opening the network connection with initializing the
|
||||
# protocol. In the current design of connect(), there is
|
||||
# no easy way to reuse the network connection that works
|
||||
# in every case nor to reinitialize the protocol.
|
||||
self.connection.close_transport()
|
||||
|
||||
uri_or_exc = self.process_redirect(exc)
|
||||
# Response is a valid redirect; follow it.
|
||||
if isinstance(uri_or_exc, str):
|
||||
self.uri = uri_or_exc
|
||||
continue
|
||||
# Response isn't a valid redirect; raise the exception.
|
||||
if uri_or_exc is exc:
|
||||
raise
|
||||
else:
|
||||
raise uri_or_exc from exc
|
||||
|
||||
else:
|
||||
self.connection.start_keepalive()
|
||||
return self.connection
|
||||
else:
|
||||
raise SecurityError(f"more than {MAX_REDIRECTS} redirects")
|
||||
|
||||
except TimeoutError:
|
||||
# Re-raise exception with an informative error message.
|
||||
raise TimeoutError("timed out during handshake") from None
|
||||
|
||||
# ... = yield from connect(...) - remove when dropping Python < 3.10
|
||||
|
||||
__iter__ = __await__
|
||||
|
||||
# async with connect(...) as ...: ...
|
||||
|
||||
async def __aenter__(self) -> ClientConnection:
|
||||
return await self
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc_value: BaseException | None,
|
||||
traceback: TracebackType | None,
|
||||
) -> None:
|
||||
await self.connection.close()
|
||||
|
||||
# async for ... in connect(...):
|
||||
|
||||
async def __aiter__(self) -> AsyncIterator[ClientConnection]:
|
||||
delays: Generator[float, None, None] | None = None
|
||||
while True:
|
||||
try:
|
||||
async with self as protocol:
|
||||
yield protocol
|
||||
except Exception as exc:
|
||||
# Determine whether the exception is retryable or fatal.
|
||||
# The API of process_exception is "return an exception or None";
|
||||
# "raise an exception" is also supported because it's a frequent
|
||||
# mistake. It isn't documented in order to keep the API simple.
|
||||
try:
|
||||
new_exc = self.process_exception(exc)
|
||||
except Exception as raised_exc:
|
||||
new_exc = raised_exc
|
||||
|
||||
# The connection failed with a fatal error.
|
||||
# Raise the exception and exit the loop.
|
||||
if new_exc is exc:
|
||||
raise
|
||||
if new_exc is not None:
|
||||
raise new_exc from exc
|
||||
|
||||
# The connection failed with a retryable error.
|
||||
# Start or continue backoff and reconnect.
|
||||
if delays is None:
|
||||
delays = backoff()
|
||||
delay = next(delays)
|
||||
self.logger.info(
|
||||
"! connect failed; reconnecting in %.1f seconds",
|
||||
delay,
|
||||
exc_info=True,
|
||||
)
|
||||
await asyncio.sleep(delay)
|
||||
continue
|
||||
|
||||
else:
|
||||
# The connection succeeded. Reset backoff.
|
||||
delays = None
|
||||
|
||||
|
||||
def unix_connect(
|
||||
path: str | None = None,
|
||||
uri: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> connect:
|
||||
"""
|
||||
Connect to a WebSocket server listening on a Unix socket.
|
||||
|
||||
This function accepts the same keyword arguments as :func:`connect`.
|
||||
|
||||
It's only available on Unix.
|
||||
|
||||
It's mainly useful for debugging servers listening on Unix sockets.
|
||||
|
||||
Args:
|
||||
path: File system path to the Unix socket.
|
||||
uri: URI of the WebSocket server. ``uri`` defaults to
|
||||
``ws://localhost/`` or, when a ``ssl`` argument is provided, to
|
||||
``wss://localhost/``.
|
||||
|
||||
"""
|
||||
if uri is None:
|
||||
if kwargs.get("ssl") is None:
|
||||
uri = "ws://localhost/"
|
||||
else:
|
||||
uri = "wss://localhost/"
|
||||
return connect(uri=uri, unix=True, path=path, **kwargs)
|
@ -1,30 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
|
||||
|
||||
__all__ = ["TimeoutError", "aiter", "anext", "asyncio_timeout", "asyncio_timeout_at"]
|
||||
|
||||
|
||||
if sys.version_info[:2] >= (3, 11):
|
||||
TimeoutError = TimeoutError
|
||||
aiter = aiter
|
||||
anext = anext
|
||||
from asyncio import (
|
||||
timeout as asyncio_timeout, # noqa: F401
|
||||
timeout_at as asyncio_timeout_at, # noqa: F401
|
||||
)
|
||||
|
||||
else: # Python < 3.11
|
||||
from asyncio import TimeoutError
|
||||
|
||||
def aiter(async_iterable):
|
||||
return type(async_iterable).__aiter__(async_iterable)
|
||||
|
||||
async def anext(async_iterator):
|
||||
return await type(async_iterator).__anext__(async_iterator)
|
||||
|
||||
from .async_timeout import (
|
||||
timeout as asyncio_timeout, # noqa: F401
|
||||
timeout_at as asyncio_timeout_at, # noqa: F401
|
||||
)
|
File diff suppressed because it is too large
Load Diff
@ -1,293 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import codecs
|
||||
import collections
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Callable,
|
||||
Generic,
|
||||
Iterable,
|
||||
TypeVar,
|
||||
)
|
||||
|
||||
from ..exceptions import ConcurrencyError
|
||||
from ..frames import OP_BINARY, OP_CONT, OP_TEXT, Frame
|
||||
from ..typing import Data
|
||||
|
||||
|
||||
__all__ = ["Assembler"]
|
||||
|
||||
UTF8Decoder = codecs.getincrementaldecoder("utf-8")
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class SimpleQueue(Generic[T]):
|
||||
"""
|
||||
Simplified version of :class:`asyncio.Queue`.
|
||||
|
||||
Provides only the subset of functionality needed by :class:`Assembler`.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.loop = asyncio.get_running_loop()
|
||||
self.get_waiter: asyncio.Future[None] | None = None
|
||||
self.queue: collections.deque[T] = collections.deque()
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.queue)
|
||||
|
||||
def put(self, item: T) -> None:
|
||||
"""Put an item into the queue without waiting."""
|
||||
self.queue.append(item)
|
||||
if self.get_waiter is not None and not self.get_waiter.done():
|
||||
self.get_waiter.set_result(None)
|
||||
|
||||
async def get(self) -> T:
|
||||
"""Remove and return an item from the queue, waiting if necessary."""
|
||||
if not self.queue:
|
||||
if self.get_waiter is not None:
|
||||
raise ConcurrencyError("get is already running")
|
||||
self.get_waiter = self.loop.create_future()
|
||||
try:
|
||||
await self.get_waiter
|
||||
finally:
|
||||
self.get_waiter.cancel()
|
||||
self.get_waiter = None
|
||||
return self.queue.popleft()
|
||||
|
||||
def reset(self, items: Iterable[T]) -> None:
|
||||
"""Put back items into an empty, idle queue."""
|
||||
assert self.get_waiter is None, "cannot reset() while get() is running"
|
||||
assert not self.queue, "cannot reset() while queue isn't empty"
|
||||
self.queue.extend(items)
|
||||
|
||||
def abort(self) -> None:
|
||||
if self.get_waiter is not None and not self.get_waiter.done():
|
||||
self.get_waiter.set_exception(EOFError("stream of frames ended"))
|
||||
# Clear the queue to avoid storing unnecessary data in memory.
|
||||
self.queue.clear()
|
||||
|
||||
|
||||
class Assembler:
|
||||
"""
|
||||
Assemble messages from frames.
|
||||
|
||||
:class:`Assembler` expects only data frames. The stream of frames must
|
||||
respect the protocol; if it doesn't, the behavior is undefined.
|
||||
|
||||
Args:
|
||||
pause: Called when the buffer of frames goes above the high water mark;
|
||||
should pause reading from the network.
|
||||
resume: Called when the buffer of frames goes below the low water mark;
|
||||
should resume reading from the network.
|
||||
|
||||
"""
|
||||
|
||||
# coverage reports incorrectly: "line NN didn't jump to the function exit"
|
||||
def __init__( # pragma: no cover
|
||||
self,
|
||||
high: int = 16,
|
||||
low: int | None = None,
|
||||
pause: Callable[[], Any] = lambda: None,
|
||||
resume: Callable[[], Any] = lambda: None,
|
||||
) -> None:
|
||||
# Queue of incoming messages. Each item is a queue of frames.
|
||||
self.frames: SimpleQueue[Frame] = SimpleQueue()
|
||||
|
||||
# We cannot put a hard limit on the size of the queue because a single
|
||||
# call to Protocol.data_received() could produce thousands of frames,
|
||||
# which must be buffered. Instead, we pause reading when the buffer goes
|
||||
# above the high limit and we resume when it goes under the low limit.
|
||||
if low is None:
|
||||
low = high // 4
|
||||
if low < 0:
|
||||
raise ValueError("low must be positive or equal to zero")
|
||||
if high < low:
|
||||
raise ValueError("high must be greater than or equal to low")
|
||||
self.high, self.low = high, low
|
||||
self.pause = pause
|
||||
self.resume = resume
|
||||
self.paused = False
|
||||
|
||||
# This flag prevents concurrent calls to get() by user code.
|
||||
self.get_in_progress = False
|
||||
|
||||
# This flag marks the end of the connection.
|
||||
self.closed = False
|
||||
|
||||
async def get(self, decode: bool | None = None) -> Data:
|
||||
"""
|
||||
Read the next message.
|
||||
|
||||
:meth:`get` returns a single :class:`str` or :class:`bytes`.
|
||||
|
||||
If the message is fragmented, :meth:`get` waits until the last frame is
|
||||
received, then it reassembles the message and returns it. To receive
|
||||
messages frame by frame, use :meth:`get_iter` instead.
|
||||
|
||||
Args:
|
||||
decode: :obj:`False` disables UTF-8 decoding of text frames and
|
||||
returns :class:`bytes`. :obj:`True` forces UTF-8 decoding of
|
||||
binary frames and returns :class:`str`.
|
||||
|
||||
Raises:
|
||||
EOFError: If the stream of frames has ended.
|
||||
ConcurrencyError: If two coroutines run :meth:`get` or
|
||||
:meth:`get_iter` concurrently.
|
||||
|
||||
"""
|
||||
if self.closed:
|
||||
raise EOFError("stream of frames ended")
|
||||
|
||||
if self.get_in_progress:
|
||||
raise ConcurrencyError("get() or get_iter() is already running")
|
||||
|
||||
# Locking with get_in_progress ensures only one coroutine can get here.
|
||||
self.get_in_progress = True
|
||||
|
||||
# First frame
|
||||
try:
|
||||
frame = await self.frames.get()
|
||||
except asyncio.CancelledError:
|
||||
self.get_in_progress = False
|
||||
raise
|
||||
self.maybe_resume()
|
||||
assert frame.opcode is OP_TEXT or frame.opcode is OP_BINARY
|
||||
if decode is None:
|
||||
decode = frame.opcode is OP_TEXT
|
||||
frames = [frame]
|
||||
|
||||
# Following frames, for fragmented messages
|
||||
while not frame.fin:
|
||||
try:
|
||||
frame = await self.frames.get()
|
||||
except asyncio.CancelledError:
|
||||
# Put frames already received back into the queue
|
||||
# so that future calls to get() can return them.
|
||||
self.frames.reset(frames)
|
||||
self.get_in_progress = False
|
||||
raise
|
||||
self.maybe_resume()
|
||||
assert frame.opcode is OP_CONT
|
||||
frames.append(frame)
|
||||
|
||||
self.get_in_progress = False
|
||||
|
||||
data = b"".join(frame.data for frame in frames)
|
||||
if decode:
|
||||
return data.decode()
|
||||
else:
|
||||
return data
|
||||
|
||||
async def get_iter(self, decode: bool | None = None) -> AsyncIterator[Data]:
|
||||
"""
|
||||
Stream the next message.
|
||||
|
||||
Iterating the return value of :meth:`get_iter` asynchronously yields a
|
||||
:class:`str` or :class:`bytes` for each frame in the message.
|
||||
|
||||
The iterator must be fully consumed before calling :meth:`get_iter` or
|
||||
:meth:`get` again. Else, :exc:`ConcurrencyError` is raised.
|
||||
|
||||
This method only makes sense for fragmented messages. If messages aren't
|
||||
fragmented, use :meth:`get` instead.
|
||||
|
||||
Args:
|
||||
decode: :obj:`False` disables UTF-8 decoding of text frames and
|
||||
returns :class:`bytes`. :obj:`True` forces UTF-8 decoding of
|
||||
binary frames and returns :class:`str`.
|
||||
|
||||
Raises:
|
||||
EOFError: If the stream of frames has ended.
|
||||
ConcurrencyError: If two coroutines run :meth:`get` or
|
||||
:meth:`get_iter` concurrently.
|
||||
|
||||
"""
|
||||
if self.closed:
|
||||
raise EOFError("stream of frames ended")
|
||||
|
||||
if self.get_in_progress:
|
||||
raise ConcurrencyError("get() or get_iter() is already running")
|
||||
|
||||
# Locking with get_in_progress ensures only one coroutine can get here.
|
||||
self.get_in_progress = True
|
||||
|
||||
# First frame
|
||||
try:
|
||||
frame = await self.frames.get()
|
||||
except asyncio.CancelledError:
|
||||
self.get_in_progress = False
|
||||
raise
|
||||
self.maybe_resume()
|
||||
assert frame.opcode is OP_TEXT or frame.opcode is OP_BINARY
|
||||
if decode is None:
|
||||
decode = frame.opcode is OP_TEXT
|
||||
if decode:
|
||||
decoder = UTF8Decoder()
|
||||
yield decoder.decode(frame.data, frame.fin)
|
||||
else:
|
||||
yield frame.data
|
||||
|
||||
# Following frames, for fragmented messages
|
||||
while not frame.fin:
|
||||
# We cannot handle asyncio.CancelledError because we don't buffer
|
||||
# previous fragments — we're streaming them. Canceling get_iter()
|
||||
# here will leave the assembler in a stuck state. Future calls to
|
||||
# get() or get_iter() will raise ConcurrencyError.
|
||||
frame = await self.frames.get()
|
||||
self.maybe_resume()
|
||||
assert frame.opcode is OP_CONT
|
||||
if decode:
|
||||
yield decoder.decode(frame.data, frame.fin)
|
||||
else:
|
||||
yield frame.data
|
||||
|
||||
self.get_in_progress = False
|
||||
|
||||
def put(self, frame: Frame) -> None:
|
||||
"""
|
||||
Add ``frame`` to the next message.
|
||||
|
||||
Raises:
|
||||
EOFError: If the stream of frames has ended.
|
||||
|
||||
"""
|
||||
if self.closed:
|
||||
raise EOFError("stream of frames ended")
|
||||
|
||||
self.frames.put(frame)
|
||||
self.maybe_pause()
|
||||
|
||||
def maybe_pause(self) -> None:
|
||||
"""Pause the writer if queue is above the high water mark."""
|
||||
# Check for "> high" to support high = 0
|
||||
if len(self.frames) > self.high and not self.paused:
|
||||
self.paused = True
|
||||
self.pause()
|
||||
|
||||
def maybe_resume(self) -> None:
|
||||
"""Resume the writer if queue is below the low water mark."""
|
||||
# Check for "<= low" to support low = 0
|
||||
if len(self.frames) <= self.low and self.paused:
|
||||
self.paused = False
|
||||
self.resume()
|
||||
|
||||
def close(self) -> None:
|
||||
"""
|
||||
End the stream of frames.
|
||||
|
||||
Callling :meth:`close` concurrently with :meth:`get`, :meth:`get_iter`,
|
||||
or :meth:`put` is safe. They will raise :exc:`EOFError`.
|
||||
|
||||
"""
|
||||
if self.closed:
|
||||
return
|
||||
|
||||
self.closed = True
|
||||
|
||||
# Unblock get() or get_iter().
|
||||
self.frames.abort()
|
@ -1,973 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import hmac
|
||||
import http
|
||||
import logging
|
||||
import socket
|
||||
import sys
|
||||
from types import TracebackType
|
||||
from typing import (
|
||||
Any,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Generator,
|
||||
Iterable,
|
||||
Sequence,
|
||||
Tuple,
|
||||
cast,
|
||||
)
|
||||
|
||||
from ..exceptions import InvalidHeader
|
||||
from ..extensions.base import ServerExtensionFactory
|
||||
from ..extensions.permessage_deflate import enable_server_permessage_deflate
|
||||
from ..frames import CloseCode
|
||||
from ..headers import (
|
||||
build_www_authenticate_basic,
|
||||
parse_authorization_basic,
|
||||
validate_subprotocols,
|
||||
)
|
||||
from ..http11 import SERVER, Request, Response
|
||||
from ..protocol import CONNECTING, OPEN, Event
|
||||
from ..server import ServerProtocol
|
||||
from ..typing import LoggerLike, Origin, StatusLike, Subprotocol
|
||||
from .compatibility import asyncio_timeout
|
||||
from .connection import Connection, broadcast
|
||||
|
||||
|
||||
__all__ = [
|
||||
"broadcast",
|
||||
"serve",
|
||||
"unix_serve",
|
||||
"ServerConnection",
|
||||
"Server",
|
||||
"basic_auth",
|
||||
]
|
||||
|
||||
|
||||
class ServerConnection(Connection):
|
||||
"""
|
||||
:mod:`asyncio` implementation of a WebSocket server connection.
|
||||
|
||||
:class:`ServerConnection` provides :meth:`recv` and :meth:`send` methods for
|
||||
receiving and sending messages.
|
||||
|
||||
It supports asynchronous iteration to receive messages::
|
||||
|
||||
async for message in websocket:
|
||||
await process(message)
|
||||
|
||||
The iterator exits normally when the connection is closed with close code
|
||||
1000 (OK) or 1001 (going away) or without a close code. It raises a
|
||||
:exc:`~websockets.exceptions.ConnectionClosedError` when the connection is
|
||||
closed with any other code.
|
||||
|
||||
The ``ping_interval``, ``ping_timeout``, ``close_timeout``, ``max_queue``,
|
||||
and ``write_limit`` arguments the same meaning as in :func:`serve`.
|
||||
|
||||
Args:
|
||||
protocol: Sans-I/O connection.
|
||||
server: Server that manages this connection.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
protocol: ServerProtocol,
|
||||
server: Server,
|
||||
*,
|
||||
ping_interval: float | None = 20,
|
||||
ping_timeout: float | None = 20,
|
||||
close_timeout: float | None = 10,
|
||||
max_queue: int | tuple[int, int | None] = 16,
|
||||
write_limit: int | tuple[int, int | None] = 2**15,
|
||||
) -> None:
|
||||
self.protocol: ServerProtocol
|
||||
super().__init__(
|
||||
protocol,
|
||||
ping_interval=ping_interval,
|
||||
ping_timeout=ping_timeout,
|
||||
close_timeout=close_timeout,
|
||||
max_queue=max_queue,
|
||||
write_limit=write_limit,
|
||||
)
|
||||
self.server = server
|
||||
self.request_rcvd: asyncio.Future[None] = self.loop.create_future()
|
||||
self.username: str # see basic_auth()
|
||||
|
||||
def respond(self, status: StatusLike, text: str) -> Response:
|
||||
"""
|
||||
Create a plain text HTTP response.
|
||||
|
||||
``process_request`` and ``process_response`` may call this method to
|
||||
return an HTTP response instead of performing the WebSocket opening
|
||||
handshake.
|
||||
|
||||
You can modify the response before returning it, for example by changing
|
||||
HTTP headers.
|
||||
|
||||
Args:
|
||||
status: HTTP status code.
|
||||
text: HTTP response body; it will be encoded to UTF-8.
|
||||
|
||||
Returns:
|
||||
HTTP response to send to the client.
|
||||
|
||||
"""
|
||||
return self.protocol.reject(status, text)
|
||||
|
||||
async def handshake(
|
||||
self,
|
||||
process_request: (
|
||||
Callable[
|
||||
[ServerConnection, Request],
|
||||
Awaitable[Response | None] | Response | None,
|
||||
]
|
||||
| None
|
||||
) = None,
|
||||
process_response: (
|
||||
Callable[
|
||||
[ServerConnection, Request, Response],
|
||||
Awaitable[Response | None] | Response | None,
|
||||
]
|
||||
| None
|
||||
) = None,
|
||||
server_header: str | None = SERVER,
|
||||
) -> None:
|
||||
"""
|
||||
Perform the opening handshake.
|
||||
|
||||
"""
|
||||
await asyncio.wait(
|
||||
[self.request_rcvd, self.connection_lost_waiter],
|
||||
return_when=asyncio.FIRST_COMPLETED,
|
||||
)
|
||||
|
||||
if self.request is not None:
|
||||
async with self.send_context(expected_state=CONNECTING):
|
||||
response = None
|
||||
|
||||
if process_request is not None:
|
||||
try:
|
||||
response = process_request(self, self.request)
|
||||
if isinstance(response, Awaitable):
|
||||
response = await response
|
||||
except Exception as exc:
|
||||
self.protocol.handshake_exc = exc
|
||||
response = self.protocol.reject(
|
||||
http.HTTPStatus.INTERNAL_SERVER_ERROR,
|
||||
(
|
||||
"Failed to open a WebSocket connection.\n"
|
||||
"See server log for more information.\n"
|
||||
),
|
||||
)
|
||||
|
||||
if response is None:
|
||||
if self.server.is_serving():
|
||||
self.response = self.protocol.accept(self.request)
|
||||
else:
|
||||
self.response = self.protocol.reject(
|
||||
http.HTTPStatus.SERVICE_UNAVAILABLE,
|
||||
"Server is shutting down.\n",
|
||||
)
|
||||
else:
|
||||
assert isinstance(response, Response) # help mypy
|
||||
self.response = response
|
||||
|
||||
if server_header:
|
||||
self.response.headers["Server"] = server_header
|
||||
|
||||
response = None
|
||||
|
||||
if process_response is not None:
|
||||
try:
|
||||
response = process_response(self, self.request, self.response)
|
||||
if isinstance(response, Awaitable):
|
||||
response = await response
|
||||
except Exception as exc:
|
||||
self.protocol.handshake_exc = exc
|
||||
response = self.protocol.reject(
|
||||
http.HTTPStatus.INTERNAL_SERVER_ERROR,
|
||||
(
|
||||
"Failed to open a WebSocket connection.\n"
|
||||
"See server log for more information.\n"
|
||||
),
|
||||
)
|
||||
|
||||
if response is not None:
|
||||
assert isinstance(response, Response) # help mypy
|
||||
self.response = response
|
||||
|
||||
self.protocol.send_response(self.response)
|
||||
|
||||
# self.protocol.handshake_exc is always set when the connection is lost
|
||||
# before receiving a request, when the request cannot be parsed, when
|
||||
# the handshake encounters an error, or when process_request or
|
||||
# process_response sends a HTTP response that rejects the handshake.
|
||||
|
||||
if self.protocol.handshake_exc is not None:
|
||||
raise self.protocol.handshake_exc
|
||||
|
||||
def process_event(self, event: Event) -> None:
|
||||
"""
|
||||
Process one incoming event.
|
||||
|
||||
"""
|
||||
# First event - handshake request.
|
||||
if self.request is None:
|
||||
assert isinstance(event, Request)
|
||||
self.request = event
|
||||
self.request_rcvd.set_result(None)
|
||||
# Later events - frames.
|
||||
else:
|
||||
super().process_event(event)
|
||||
|
||||
def connection_made(self, transport: asyncio.BaseTransport) -> None:
|
||||
super().connection_made(transport)
|
||||
self.server.start_connection_handler(self)
|
||||
|
||||
|
||||
class Server:
|
||||
"""
|
||||
WebSocket server returned by :func:`serve`.
|
||||
|
||||
This class mirrors the API of :class:`asyncio.Server`.
|
||||
|
||||
It keeps track of WebSocket connections in order to close them properly
|
||||
when shutting down.
|
||||
|
||||
Args:
|
||||
handler: Connection handler. It receives the WebSocket connection,
|
||||
which is a :class:`ServerConnection`, in argument.
|
||||
process_request: Intercept the request during the opening handshake.
|
||||
Return an HTTP response to force the response. Return :obj:`None` to
|
||||
continue normally. When you force an HTTP 101 Continue response, the
|
||||
handshake is successful. Else, the connection is aborted.
|
||||
``process_request`` may be a function or a coroutine.
|
||||
process_response: Intercept the response during the opening handshake.
|
||||
Modify the response or return a new HTTP response to force the
|
||||
response. Return :obj:`None` to continue normally. When you force an
|
||||
HTTP 101 Continue response, the handshake is successful. Else, the
|
||||
connection is aborted. ``process_response`` may be a function or a
|
||||
coroutine.
|
||||
server_header: Value of the ``Server`` response header.
|
||||
It defaults to ``"Python/x.y.z websockets/X.Y"``. Setting it to
|
||||
:obj:`None` removes the header.
|
||||
open_timeout: Timeout for opening connections in seconds.
|
||||
:obj:`None` disables the timeout.
|
||||
logger: Logger for this server.
|
||||
It defaults to ``logging.getLogger("websockets.server")``.
|
||||
See the :doc:`logging guide <../../topics/logging>` for details.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
handler: Callable[[ServerConnection], Awaitable[None]],
|
||||
*,
|
||||
process_request: (
|
||||
Callable[
|
||||
[ServerConnection, Request],
|
||||
Awaitable[Response | None] | Response | None,
|
||||
]
|
||||
| None
|
||||
) = None,
|
||||
process_response: (
|
||||
Callable[
|
||||
[ServerConnection, Request, Response],
|
||||
Awaitable[Response | None] | Response | None,
|
||||
]
|
||||
| None
|
||||
) = None,
|
||||
server_header: str | None = SERVER,
|
||||
open_timeout: float | None = 10,
|
||||
logger: LoggerLike | None = None,
|
||||
) -> None:
|
||||
self.loop = asyncio.get_running_loop()
|
||||
self.handler = handler
|
||||
self.process_request = process_request
|
||||
self.process_response = process_response
|
||||
self.server_header = server_header
|
||||
self.open_timeout = open_timeout
|
||||
if logger is None:
|
||||
logger = logging.getLogger("websockets.server")
|
||||
self.logger = logger
|
||||
|
||||
# Keep track of active connections.
|
||||
self.handlers: dict[ServerConnection, asyncio.Task[None]] = {}
|
||||
|
||||
# Task responsible for closing the server and terminating connections.
|
||||
self.close_task: asyncio.Task[None] | None = None
|
||||
|
||||
# Completed when the server is closed and connections are terminated.
|
||||
self.closed_waiter: asyncio.Future[None] = self.loop.create_future()
|
||||
|
||||
@property
|
||||
def connections(self) -> set[ServerConnection]:
|
||||
"""
|
||||
Set of active connections.
|
||||
|
||||
This property contains all connections that completed the opening
|
||||
handshake successfully and didn't start the closing handshake yet.
|
||||
It can be useful in combination with :func:`~broadcast`.
|
||||
|
||||
"""
|
||||
return {connection for connection in self.handlers if connection.state is OPEN}
|
||||
|
||||
def wrap(self, server: asyncio.Server) -> None:
|
||||
"""
|
||||
Attach to a given :class:`asyncio.Server`.
|
||||
|
||||
Since :meth:`~asyncio.loop.create_server` doesn't support injecting a
|
||||
custom ``Server`` class, the easiest solution that doesn't rely on
|
||||
private :mod:`asyncio` APIs is to:
|
||||
|
||||
- instantiate a :class:`Server`
|
||||
- give the protocol factory a reference to that instance
|
||||
- call :meth:`~asyncio.loop.create_server` with the factory
|
||||
- attach the resulting :class:`asyncio.Server` with this method
|
||||
|
||||
"""
|
||||
self.server = server
|
||||
for sock in server.sockets:
|
||||
if sock.family == socket.AF_INET:
|
||||
name = "%s:%d" % sock.getsockname()
|
||||
elif sock.family == socket.AF_INET6:
|
||||
name = "[%s]:%d" % sock.getsockname()[:2]
|
||||
elif sock.family == socket.AF_UNIX:
|
||||
name = sock.getsockname()
|
||||
# In the unlikely event that someone runs websockets over a
|
||||
# protocol other than IP or Unix sockets, avoid crashing.
|
||||
else: # pragma: no cover
|
||||
name = str(sock.getsockname())
|
||||
self.logger.info("server listening on %s", name)
|
||||
|
||||
async def conn_handler(self, connection: ServerConnection) -> None:
|
||||
"""
|
||||
Handle the lifecycle of a WebSocket connection.
|
||||
|
||||
Since this method doesn't have a caller that can handle exceptions,
|
||||
it attempts to log relevant ones.
|
||||
|
||||
It guarantees that the TCP connection is closed before exiting.
|
||||
|
||||
"""
|
||||
try:
|
||||
async with asyncio_timeout(self.open_timeout):
|
||||
try:
|
||||
await connection.handshake(
|
||||
self.process_request,
|
||||
self.process_response,
|
||||
self.server_header,
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
connection.close_transport()
|
||||
raise
|
||||
except Exception:
|
||||
connection.logger.error("opening handshake failed", exc_info=True)
|
||||
connection.close_transport()
|
||||
return
|
||||
|
||||
assert connection.protocol.state is OPEN
|
||||
try:
|
||||
connection.start_keepalive()
|
||||
await self.handler(connection)
|
||||
except Exception:
|
||||
connection.logger.error("connection handler failed", exc_info=True)
|
||||
await connection.close(CloseCode.INTERNAL_ERROR)
|
||||
else:
|
||||
await connection.close()
|
||||
|
||||
except TimeoutError:
|
||||
# When the opening handshake times out, there's nothing to log.
|
||||
pass
|
||||
|
||||
except Exception: # pragma: no cover
|
||||
# Don't leak connections on unexpected errors.
|
||||
connection.transport.abort()
|
||||
|
||||
finally:
|
||||
# Registration is tied to the lifecycle of conn_handler() because
|
||||
# the server waits for connection handlers to terminate, even if
|
||||
# all connections are already closed.
|
||||
del self.handlers[connection]
|
||||
|
||||
def start_connection_handler(self, connection: ServerConnection) -> None:
|
||||
"""
|
||||
Register a connection with this server.
|
||||
|
||||
"""
|
||||
# The connection must be registered in self.handlers immediately.
|
||||
# If it was registered in conn_handler(), a race condition could
|
||||
# happen when closing the server after scheduling conn_handler()
|
||||
# but before it starts executing.
|
||||
self.handlers[connection] = self.loop.create_task(self.conn_handler(connection))
|
||||
|
||||
def close(self, close_connections: bool = True) -> None:
|
||||
"""
|
||||
Close the server.
|
||||
|
||||
* Close the underlying :class:`asyncio.Server`.
|
||||
* When ``close_connections`` is :obj:`True`, which is the default,
|
||||
close existing connections. Specifically:
|
||||
|
||||
* Reject opening WebSocket connections with an HTTP 503 (service
|
||||
unavailable) error. This happens when the server accepted the TCP
|
||||
connection but didn't complete the opening handshake before closing.
|
||||
* Close open WebSocket connections with close code 1001 (going away).
|
||||
|
||||
* Wait until all connection handlers terminate.
|
||||
|
||||
:meth:`close` is idempotent.
|
||||
|
||||
"""
|
||||
if self.close_task is None:
|
||||
self.close_task = self.get_loop().create_task(
|
||||
self._close(close_connections)
|
||||
)
|
||||
|
||||
async def _close(self, close_connections: bool) -> None:
|
||||
"""
|
||||
Implementation of :meth:`close`.
|
||||
|
||||
This calls :meth:`~asyncio.Server.close` on the underlying
|
||||
:class:`asyncio.Server` object to stop accepting new connections and
|
||||
then closes open connections with close code 1001.
|
||||
|
||||
"""
|
||||
self.logger.info("server closing")
|
||||
|
||||
# Stop accepting new connections.
|
||||
self.server.close()
|
||||
|
||||
# Wait until all accepted connections reach connection_made() and call
|
||||
# register(). See https://github.com/python/cpython/issues/79033 for
|
||||
# details. This workaround can be removed when dropping Python < 3.11.
|
||||
await asyncio.sleep(0)
|
||||
|
||||
if close_connections:
|
||||
# Close OPEN connections with close code 1001. After server.close(),
|
||||
# handshake() closes OPENING connections with an HTTP 503 error.
|
||||
close_tasks = [
|
||||
asyncio.create_task(connection.close(1001))
|
||||
for connection in self.handlers
|
||||
if connection.protocol.state is not CONNECTING
|
||||
]
|
||||
# asyncio.wait doesn't accept an empty first argument.
|
||||
if close_tasks:
|
||||
await asyncio.wait(close_tasks)
|
||||
|
||||
# Wait until all TCP connections are closed.
|
||||
await self.server.wait_closed()
|
||||
|
||||
# Wait until all connection handlers terminate.
|
||||
# asyncio.wait doesn't accept an empty first argument.
|
||||
if self.handlers:
|
||||
await asyncio.wait(self.handlers.values())
|
||||
|
||||
# Tell wait_closed() to return.
|
||||
self.closed_waiter.set_result(None)
|
||||
|
||||
self.logger.info("server closed")
|
||||
|
||||
async def wait_closed(self) -> None:
|
||||
"""
|
||||
Wait until the server is closed.
|
||||
|
||||
When :meth:`wait_closed` returns, all TCP connections are closed and
|
||||
all connection handlers have returned.
|
||||
|
||||
To ensure a fast shutdown, a connection handler should always be
|
||||
awaiting at least one of:
|
||||
|
||||
* :meth:`~ServerConnection.recv`: when the connection is closed,
|
||||
it raises :exc:`~websockets.exceptions.ConnectionClosedOK`;
|
||||
* :meth:`~ServerConnection.wait_closed`: when the connection is
|
||||
closed, it returns.
|
||||
|
||||
Then the connection handler is immediately notified of the shutdown;
|
||||
it can clean up and exit.
|
||||
|
||||
"""
|
||||
await asyncio.shield(self.closed_waiter)
|
||||
|
||||
def get_loop(self) -> asyncio.AbstractEventLoop:
|
||||
"""
|
||||
See :meth:`asyncio.Server.get_loop`.
|
||||
|
||||
"""
|
||||
return self.server.get_loop()
|
||||
|
||||
def is_serving(self) -> bool: # pragma: no cover
|
||||
"""
|
||||
See :meth:`asyncio.Server.is_serving`.
|
||||
|
||||
"""
|
||||
return self.server.is_serving()
|
||||
|
||||
async def start_serving(self) -> None: # pragma: no cover
|
||||
"""
|
||||
See :meth:`asyncio.Server.start_serving`.
|
||||
|
||||
Typical use::
|
||||
|
||||
server = await serve(..., start_serving=False)
|
||||
# perform additional setup here...
|
||||
# ... then start the server
|
||||
await server.start_serving()
|
||||
|
||||
"""
|
||||
await self.server.start_serving()
|
||||
|
||||
async def serve_forever(self) -> None: # pragma: no cover
|
||||
"""
|
||||
See :meth:`asyncio.Server.serve_forever`.
|
||||
|
||||
Typical use::
|
||||
|
||||
server = await serve(...)
|
||||
# this coroutine doesn't return
|
||||
# canceling it stops the server
|
||||
await server.serve_forever()
|
||||
|
||||
This is an alternative to using :func:`serve` as an asynchronous context
|
||||
manager. Shutdown is triggered by canceling :meth:`serve_forever`
|
||||
instead of exiting a :func:`serve` context.
|
||||
|
||||
"""
|
||||
await self.server.serve_forever()
|
||||
|
||||
@property
|
||||
def sockets(self) -> Iterable[socket.socket]:
|
||||
"""
|
||||
See :attr:`asyncio.Server.sockets`.
|
||||
|
||||
"""
|
||||
return self.server.sockets
|
||||
|
||||
async def __aenter__(self) -> Server: # pragma: no cover
|
||||
return self
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc_value: BaseException | None,
|
||||
traceback: TracebackType | None,
|
||||
) -> None: # pragma: no cover
|
||||
self.close()
|
||||
await self.wait_closed()
|
||||
|
||||
|
||||
# This is spelled in lower case because it's exposed as a callable in the API.
|
||||
class serve:
|
||||
"""
|
||||
Create a WebSocket server listening on ``host`` and ``port``.
|
||||
|
||||
Whenever a client connects, the server creates a :class:`ServerConnection`,
|
||||
performs the opening handshake, and delegates to the ``handler`` coroutine.
|
||||
|
||||
The handler receives the :class:`ServerConnection` instance, which you can
|
||||
use to send and receive messages.
|
||||
|
||||
Once the handler completes, either normally or with an exception, the server
|
||||
performs the closing handshake and closes the connection.
|
||||
|
||||
This coroutine returns a :class:`Server` whose API mirrors
|
||||
:class:`asyncio.Server`. Treat it as an asynchronous context manager to
|
||||
ensure that the server will be closed::
|
||||
|
||||
from websockets.asyncio.server import serve
|
||||
|
||||
def handler(websocket):
|
||||
...
|
||||
|
||||
# set this future to exit the server
|
||||
stop = asyncio.get_running_loop().create_future()
|
||||
|
||||
async with serve(handler, host, port):
|
||||
await stop
|
||||
|
||||
Alternatively, call :meth:`~Server.serve_forever` to serve requests and
|
||||
cancel it to stop the server::
|
||||
|
||||
server = await serve(handler, host, port)
|
||||
await server.serve_forever()
|
||||
|
||||
Args:
|
||||
handler: Connection handler. It receives the WebSocket connection,
|
||||
which is a :class:`ServerConnection`, in argument.
|
||||
host: Network interfaces the server binds to.
|
||||
See :meth:`~asyncio.loop.create_server` for details.
|
||||
port: TCP port the server listens on.
|
||||
See :meth:`~asyncio.loop.create_server` for details.
|
||||
origins: Acceptable values of the ``Origin`` header, for defending
|
||||
against Cross-Site WebSocket Hijacking attacks. Include :obj:`None`
|
||||
in the list if the lack of an origin is acceptable.
|
||||
extensions: List of supported extensions, in order in which they
|
||||
should be negotiated and run.
|
||||
subprotocols: List of supported subprotocols, in order of decreasing
|
||||
preference.
|
||||
select_subprotocol: Callback for selecting a subprotocol among
|
||||
those supported by the client and the server. It receives a
|
||||
:class:`ServerConnection` (not a
|
||||
:class:`~websockets.server.ServerProtocol`!) instance and a list of
|
||||
subprotocols offered by the client. Other than the first argument,
|
||||
it has the same behavior as the
|
||||
:meth:`ServerProtocol.select_subprotocol
|
||||
<websockets.server.ServerProtocol.select_subprotocol>` method.
|
||||
process_request: Intercept the request during the opening handshake.
|
||||
Return an HTTP response to force the response or :obj:`None` to
|
||||
continue normally. When you force an HTTP 101 Continue response, the
|
||||
handshake is successful. Else, the connection is aborted.
|
||||
``process_request`` may be a function or a coroutine.
|
||||
process_response: Intercept the response during the opening handshake.
|
||||
Return an HTTP response to force the response or :obj:`None` to
|
||||
continue normally. When you force an HTTP 101 Continue response, the
|
||||
handshake is successful. Else, the connection is aborted.
|
||||
``process_response`` may be a function or a coroutine.
|
||||
server_header: Value of the ``Server`` response header.
|
||||
It defaults to ``"Python/x.y.z websockets/X.Y"``. Setting it to
|
||||
:obj:`None` removes the header.
|
||||
compression: The "permessage-deflate" extension is enabled by default.
|
||||
Set ``compression`` to :obj:`None` to disable it. See the
|
||||
:doc:`compression guide <../../topics/compression>` for details.
|
||||
open_timeout: Timeout for opening connections in seconds.
|
||||
:obj:`None` disables the timeout.
|
||||
ping_interval: Interval between keepalive pings in seconds.
|
||||
:obj:`None` disables keepalive.
|
||||
ping_timeout: Timeout for keepalive pings in seconds.
|
||||
:obj:`None` disables timeouts.
|
||||
close_timeout: Timeout for closing connections in seconds.
|
||||
:obj:`None` disables the timeout.
|
||||
max_size: Maximum size of incoming messages in bytes.
|
||||
:obj:`None` disables the limit.
|
||||
max_queue: High-water mark of the buffer where frames are received.
|
||||
It defaults to 16 frames. The low-water mark defaults to ``max_queue
|
||||
// 4``. You may pass a ``(high, low)`` tuple to set the high-water
|
||||
and low-water marks.
|
||||
write_limit: High-water mark of write buffer in bytes. It is passed to
|
||||
:meth:`~asyncio.WriteTransport.set_write_buffer_limits`. It defaults
|
||||
to 32 KiB. You may pass a ``(high, low)`` tuple to set the
|
||||
high-water and low-water marks.
|
||||
logger: Logger for this server.
|
||||
It defaults to ``logging.getLogger("websockets.server")``. See the
|
||||
:doc:`logging guide <../../topics/logging>` for details.
|
||||
create_connection: Factory for the :class:`ServerConnection` managing
|
||||
the connection. Set it to a wrapper or a subclass to customize
|
||||
connection handling.
|
||||
|
||||
Any other keyword arguments are passed to the event loop's
|
||||
:meth:`~asyncio.loop.create_server` method.
|
||||
|
||||
For example:
|
||||
|
||||
* You can set ``ssl`` to a :class:`~ssl.SSLContext` to enable TLS.
|
||||
|
||||
* You can set ``sock`` to provide a preexisting TCP socket. You may call
|
||||
:func:`socket.create_server` (not to be confused with the event loop's
|
||||
:meth:`~asyncio.loop.create_server` method) to create a suitable server
|
||||
socket and customize it.
|
||||
|
||||
* You can set ``start_serving`` to ``False`` to start accepting connections
|
||||
only after you call :meth:`~Server.start_serving()` or
|
||||
:meth:`~Server.serve_forever()`.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
handler: Callable[[ServerConnection], Awaitable[None]],
|
||||
host: str | None = None,
|
||||
port: int | None = None,
|
||||
*,
|
||||
# WebSocket
|
||||
origins: Sequence[Origin | None] | None = None,
|
||||
extensions: Sequence[ServerExtensionFactory] | None = None,
|
||||
subprotocols: Sequence[Subprotocol] | None = None,
|
||||
select_subprotocol: (
|
||||
Callable[
|
||||
[ServerConnection, Sequence[Subprotocol]],
|
||||
Subprotocol | None,
|
||||
]
|
||||
| None
|
||||
) = None,
|
||||
process_request: (
|
||||
Callable[
|
||||
[ServerConnection, Request],
|
||||
Awaitable[Response | None] | Response | None,
|
||||
]
|
||||
| None
|
||||
) = None,
|
||||
process_response: (
|
||||
Callable[
|
||||
[ServerConnection, Request, Response],
|
||||
Awaitable[Response | None] | Response | None,
|
||||
]
|
||||
| None
|
||||
) = None,
|
||||
server_header: str | None = SERVER,
|
||||
compression: str | None = "deflate",
|
||||
# Timeouts
|
||||
open_timeout: float | None = 10,
|
||||
ping_interval: float | None = 20,
|
||||
ping_timeout: float | None = 20,
|
||||
close_timeout: float | None = 10,
|
||||
# Limits
|
||||
max_size: int | None = 2**20,
|
||||
max_queue: int | tuple[int, int | None] = 16,
|
||||
write_limit: int | tuple[int, int | None] = 2**15,
|
||||
# Logging
|
||||
logger: LoggerLike | None = None,
|
||||
# Escape hatch for advanced customization
|
||||
create_connection: type[ServerConnection] | None = None,
|
||||
# Other keyword arguments are passed to loop.create_server
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
if subprotocols is not None:
|
||||
validate_subprotocols(subprotocols)
|
||||
|
||||
if compression == "deflate":
|
||||
extensions = enable_server_permessage_deflate(extensions)
|
||||
elif compression is not None:
|
||||
raise ValueError(f"unsupported compression: {compression}")
|
||||
|
||||
if create_connection is None:
|
||||
create_connection = ServerConnection
|
||||
|
||||
self.server = Server(
|
||||
handler,
|
||||
process_request=process_request,
|
||||
process_response=process_response,
|
||||
server_header=server_header,
|
||||
open_timeout=open_timeout,
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
if kwargs.get("ssl") is not None:
|
||||
kwargs.setdefault("ssl_handshake_timeout", open_timeout)
|
||||
if sys.version_info[:2] >= (3, 11): # pragma: no branch
|
||||
kwargs.setdefault("ssl_shutdown_timeout", close_timeout)
|
||||
|
||||
def factory() -> ServerConnection:
|
||||
"""
|
||||
Create an asyncio protocol for managing a WebSocket connection.
|
||||
|
||||
"""
|
||||
# Create a closure to give select_subprotocol access to connection.
|
||||
protocol_select_subprotocol: (
|
||||
Callable[
|
||||
[ServerProtocol, Sequence[Subprotocol]],
|
||||
Subprotocol | None,
|
||||
]
|
||||
| None
|
||||
) = None
|
||||
if select_subprotocol is not None:
|
||||
|
||||
def protocol_select_subprotocol(
|
||||
protocol: ServerProtocol,
|
||||
subprotocols: Sequence[Subprotocol],
|
||||
) -> Subprotocol | None:
|
||||
# mypy doesn't know that select_subprotocol is immutable.
|
||||
assert select_subprotocol is not None
|
||||
# Ensure this function is only used in the intended context.
|
||||
assert protocol is connection.protocol
|
||||
return select_subprotocol(connection, subprotocols)
|
||||
|
||||
# This is a protocol in the Sans-I/O implementation of websockets.
|
||||
protocol = ServerProtocol(
|
||||
origins=origins,
|
||||
extensions=extensions,
|
||||
subprotocols=subprotocols,
|
||||
select_subprotocol=protocol_select_subprotocol,
|
||||
max_size=max_size,
|
||||
logger=logger,
|
||||
)
|
||||
# This is a connection in websockets and a protocol in asyncio.
|
||||
connection = create_connection(
|
||||
protocol,
|
||||
self.server,
|
||||
ping_interval=ping_interval,
|
||||
ping_timeout=ping_timeout,
|
||||
close_timeout=close_timeout,
|
||||
max_queue=max_queue,
|
||||
write_limit=write_limit,
|
||||
)
|
||||
return connection
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
if kwargs.pop("unix", False):
|
||||
self.create_server = loop.create_unix_server(factory, **kwargs)
|
||||
else:
|
||||
# mypy cannot tell that kwargs must provide sock when port is None.
|
||||
self.create_server = loop.create_server(factory, host, port, **kwargs) # type: ignore[arg-type]
|
||||
|
||||
# async with serve(...) as ...: ...
|
||||
|
||||
async def __aenter__(self) -> Server:
|
||||
return await self
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc_value: BaseException | None,
|
||||
traceback: TracebackType | None,
|
||||
) -> None:
|
||||
self.server.close()
|
||||
await self.server.wait_closed()
|
||||
|
||||
# ... = await serve(...)
|
||||
|
||||
def __await__(self) -> Generator[Any, None, Server]:
|
||||
# Create a suitable iterator by calling __await__ on a coroutine.
|
||||
return self.__await_impl__().__await__()
|
||||
|
||||
async def __await_impl__(self) -> Server:
|
||||
server = await self.create_server
|
||||
self.server.wrap(server)
|
||||
return self.server
|
||||
|
||||
# ... = yield from serve(...) - remove when dropping Python < 3.10
|
||||
|
||||
__iter__ = __await__
|
||||
|
||||
|
||||
def unix_serve(
|
||||
handler: Callable[[ServerConnection], Awaitable[None]],
|
||||
path: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Awaitable[Server]:
|
||||
"""
|
||||
Create a WebSocket server listening on a Unix socket.
|
||||
|
||||
This function is identical to :func:`serve`, except the ``host`` and
|
||||
``port`` arguments are replaced by ``path``. It's only available on Unix.
|
||||
|
||||
It's useful for deploying a server behind a reverse proxy such as nginx.
|
||||
|
||||
Args:
|
||||
handler: Connection handler. It receives the WebSocket connection,
|
||||
which is a :class:`ServerConnection`, in argument.
|
||||
path: File system path to the Unix socket.
|
||||
|
||||
"""
|
||||
return serve(handler, unix=True, path=path, **kwargs)
|
||||
|
||||
|
||||
def is_credentials(credentials: Any) -> bool:
|
||||
try:
|
||||
username, password = credentials
|
||||
except (TypeError, ValueError):
|
||||
return False
|
||||
else:
|
||||
return isinstance(username, str) and isinstance(password, str)
|
||||
|
||||
|
||||
def basic_auth(
|
||||
realm: str = "",
|
||||
credentials: tuple[str, str] | Iterable[tuple[str, str]] | None = None,
|
||||
check_credentials: Callable[[str, str], Awaitable[bool] | bool] | None = None,
|
||||
) -> Callable[[ServerConnection, Request], Awaitable[Response | None]]:
|
||||
"""
|
||||
Factory for ``process_request`` to enforce HTTP Basic Authentication.
|
||||
|
||||
:func:`basic_auth` is designed to integrate with :func:`serve` as follows::
|
||||
|
||||
from websockets.asyncio.server import basic_auth, serve
|
||||
|
||||
async with serve(
|
||||
...,
|
||||
process_request=basic_auth(
|
||||
realm="my dev server",
|
||||
credentials=("hello", "iloveyou"),
|
||||
),
|
||||
):
|
||||
|
||||
If authentication succeeds, the connection's ``username`` attribute is set.
|
||||
If it fails, the server responds with an HTTP 401 Unauthorized status.
|
||||
|
||||
One of ``credentials`` or ``check_credentials`` must be provided; not both.
|
||||
|
||||
Args:
|
||||
realm: Scope of protection. It should contain only ASCII characters
|
||||
because the encoding of non-ASCII characters is undefined. Refer to
|
||||
section 2.2 of :rfc:`7235` for details.
|
||||
credentials: Hard coded authorized credentials. It can be a
|
||||
``(username, password)`` pair or a list of such pairs.
|
||||
check_credentials: Function or coroutine that verifies credentials.
|
||||
It receives ``username`` and ``password`` arguments and returns
|
||||
whether they're valid.
|
||||
Raises:
|
||||
TypeError: If ``credentials`` or ``check_credentials`` is wrong.
|
||||
|
||||
"""
|
||||
if (credentials is None) == (check_credentials is None):
|
||||
raise TypeError("provide either credentials or check_credentials")
|
||||
|
||||
if credentials is not None:
|
||||
if is_credentials(credentials):
|
||||
credentials_list = [cast(Tuple[str, str], credentials)]
|
||||
elif isinstance(credentials, Iterable):
|
||||
credentials_list = list(cast(Iterable[Tuple[str, str]], credentials))
|
||||
if not all(is_credentials(item) for item in credentials_list):
|
||||
raise TypeError(f"invalid credentials argument: {credentials}")
|
||||
else:
|
||||
raise TypeError(f"invalid credentials argument: {credentials}")
|
||||
|
||||
credentials_dict = dict(credentials_list)
|
||||
|
||||
def check_credentials(username: str, password: str) -> bool:
|
||||
try:
|
||||
expected_password = credentials_dict[username]
|
||||
except KeyError:
|
||||
return False
|
||||
return hmac.compare_digest(expected_password, password)
|
||||
|
||||
assert check_credentials is not None # help mypy
|
||||
|
||||
async def process_request(
|
||||
connection: ServerConnection,
|
||||
request: Request,
|
||||
) -> Response | None:
|
||||
"""
|
||||
Perform HTTP Basic Authentication.
|
||||
|
||||
If it succeeds, set the connection's ``username`` attribute and return
|
||||
:obj:`None`. If it fails, return an HTTP 401 Unauthorized responss.
|
||||
|
||||
"""
|
||||
try:
|
||||
authorization = request.headers["Authorization"]
|
||||
except KeyError:
|
||||
response = connection.respond(
|
||||
http.HTTPStatus.UNAUTHORIZED,
|
||||
"Missing credentials\n",
|
||||
)
|
||||
response.headers["WWW-Authenticate"] = build_www_authenticate_basic(realm)
|
||||
return response
|
||||
|
||||
try:
|
||||
username, password = parse_authorization_basic(authorization)
|
||||
except InvalidHeader:
|
||||
response = connection.respond(
|
||||
http.HTTPStatus.UNAUTHORIZED,
|
||||
"Unsupported credentials\n",
|
||||
)
|
||||
response.headers["WWW-Authenticate"] = build_www_authenticate_basic(realm)
|
||||
return response
|
||||
|
||||
valid_credentials = check_credentials(username, password)
|
||||
if isinstance(valid_credentials, Awaitable):
|
||||
valid_credentials = await valid_credentials
|
||||
|
||||
if not valid_credentials:
|
||||
response = connection.respond(
|
||||
http.HTTPStatus.UNAUTHORIZED,
|
||||
"Invalid credentials\n",
|
||||
)
|
||||
response.headers["WWW-Authenticate"] = build_www_authenticate_basic(realm)
|
||||
return response
|
||||
|
||||
connection.username = username
|
||||
return None
|
||||
|
||||
return process_request
|
@ -1,6 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
# See #940 for why lazy_import isn't used here for backwards compatibility.
|
||||
# See #1400 for why listing compatibility imports in __all__ helps PyCharm.
|
||||
from .legacy.auth import *
|
||||
from .legacy.auth import __all__ # noqa: F401
|
@ -1,393 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import random
|
||||
import warnings
|
||||
from typing import Any, Generator, Sequence
|
||||
|
||||
from .datastructures import Headers, MultipleValuesError
|
||||
from .exceptions import (
|
||||
InvalidHandshake,
|
||||
InvalidHeader,
|
||||
InvalidHeaderValue,
|
||||
InvalidStatus,
|
||||
InvalidUpgrade,
|
||||
NegotiationError,
|
||||
)
|
||||
from .extensions import ClientExtensionFactory, Extension
|
||||
from .headers import (
|
||||
build_authorization_basic,
|
||||
build_extension,
|
||||
build_host,
|
||||
build_subprotocol,
|
||||
parse_connection,
|
||||
parse_extension,
|
||||
parse_subprotocol,
|
||||
parse_upgrade,
|
||||
)
|
||||
from .http11 import Request, Response
|
||||
from .protocol import CLIENT, CONNECTING, OPEN, Protocol, State
|
||||
from .typing import (
|
||||
ConnectionOption,
|
||||
ExtensionHeader,
|
||||
LoggerLike,
|
||||
Origin,
|
||||
Subprotocol,
|
||||
UpgradeProtocol,
|
||||
)
|
||||
from .uri import WebSocketURI
|
||||
from .utils import accept_key, generate_key
|
||||
|
||||
|
||||
# See #940 for why lazy_import isn't used here for backwards compatibility.
|
||||
# See #1400 for why listing compatibility imports in __all__ helps PyCharm.
|
||||
from .legacy.client import * # isort:skip # noqa: I001
|
||||
from .legacy.client import __all__ as legacy__all__
|
||||
|
||||
|
||||
__all__ = ["ClientProtocol"] + legacy__all__
|
||||
|
||||
|
||||
class ClientProtocol(Protocol):
|
||||
"""
|
||||
Sans-I/O implementation of a WebSocket client connection.
|
||||
|
||||
Args:
|
||||
wsuri: URI of the WebSocket server, parsed
|
||||
with :func:`~websockets.uri.parse_uri`.
|
||||
origin: Value of the ``Origin`` header. This is useful when connecting
|
||||
to a server that validates the ``Origin`` header to defend against
|
||||
Cross-Site WebSocket Hijacking attacks.
|
||||
extensions: List of supported extensions, in order in which they
|
||||
should be tried.
|
||||
subprotocols: List of supported subprotocols, in order of decreasing
|
||||
preference.
|
||||
state: Initial state of the WebSocket connection.
|
||||
max_size: Maximum size of incoming messages in bytes;
|
||||
:obj:`None` disables the limit.
|
||||
logger: Logger for this connection;
|
||||
defaults to ``logging.getLogger("websockets.client")``;
|
||||
see the :doc:`logging guide <../../topics/logging>` for details.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
wsuri: WebSocketURI,
|
||||
*,
|
||||
origin: Origin | None = None,
|
||||
extensions: Sequence[ClientExtensionFactory] | None = None,
|
||||
subprotocols: Sequence[Subprotocol] | None = None,
|
||||
state: State = CONNECTING,
|
||||
max_size: int | None = 2**20,
|
||||
logger: LoggerLike | None = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
side=CLIENT,
|
||||
state=state,
|
||||
max_size=max_size,
|
||||
logger=logger,
|
||||
)
|
||||
self.wsuri = wsuri
|
||||
self.origin = origin
|
||||
self.available_extensions = extensions
|
||||
self.available_subprotocols = subprotocols
|
||||
self.key = generate_key()
|
||||
|
||||
def connect(self) -> Request:
|
||||
"""
|
||||
Create a handshake request to open a connection.
|
||||
|
||||
You must send the handshake request with :meth:`send_request`.
|
||||
|
||||
You can modify it before sending it, for example to add HTTP headers.
|
||||
|
||||
Returns:
|
||||
WebSocket handshake request event to send to the server.
|
||||
|
||||
"""
|
||||
headers = Headers()
|
||||
|
||||
headers["Host"] = build_host(
|
||||
self.wsuri.host, self.wsuri.port, self.wsuri.secure
|
||||
)
|
||||
|
||||
if self.wsuri.user_info:
|
||||
headers["Authorization"] = build_authorization_basic(*self.wsuri.user_info)
|
||||
|
||||
if self.origin is not None:
|
||||
headers["Origin"] = self.origin
|
||||
|
||||
headers["Upgrade"] = "websocket"
|
||||
headers["Connection"] = "Upgrade"
|
||||
headers["Sec-WebSocket-Key"] = self.key
|
||||
headers["Sec-WebSocket-Version"] = "13"
|
||||
|
||||
if self.available_extensions is not None:
|
||||
extensions_header = build_extension(
|
||||
[
|
||||
(extension_factory.name, extension_factory.get_request_params())
|
||||
for extension_factory in self.available_extensions
|
||||
]
|
||||
)
|
||||
headers["Sec-WebSocket-Extensions"] = extensions_header
|
||||
|
||||
if self.available_subprotocols is not None:
|
||||
protocol_header = build_subprotocol(self.available_subprotocols)
|
||||
headers["Sec-WebSocket-Protocol"] = protocol_header
|
||||
|
||||
return Request(self.wsuri.resource_name, headers)
|
||||
|
||||
def process_response(self, response: Response) -> None:
|
||||
"""
|
||||
Check a handshake response.
|
||||
|
||||
Args:
|
||||
request: WebSocket handshake response received from the server.
|
||||
|
||||
Raises:
|
||||
InvalidHandshake: If the handshake response is invalid.
|
||||
|
||||
"""
|
||||
|
||||
if response.status_code != 101:
|
||||
raise InvalidStatus(response)
|
||||
|
||||
headers = response.headers
|
||||
|
||||
connection: list[ConnectionOption] = sum(
|
||||
[parse_connection(value) for value in headers.get_all("Connection")], []
|
||||
)
|
||||
|
||||
if not any(value.lower() == "upgrade" for value in connection):
|
||||
raise InvalidUpgrade(
|
||||
"Connection", ", ".join(connection) if connection else None
|
||||
)
|
||||
|
||||
upgrade: list[UpgradeProtocol] = sum(
|
||||
[parse_upgrade(value) for value in headers.get_all("Upgrade")], []
|
||||
)
|
||||
|
||||
# For compatibility with non-strict implementations, ignore case when
|
||||
# checking the Upgrade header. It's supposed to be 'WebSocket'.
|
||||
if not (len(upgrade) == 1 and upgrade[0].lower() == "websocket"):
|
||||
raise InvalidUpgrade("Upgrade", ", ".join(upgrade) if upgrade else None)
|
||||
|
||||
try:
|
||||
s_w_accept = headers["Sec-WebSocket-Accept"]
|
||||
except KeyError as exc:
|
||||
raise InvalidHeader("Sec-WebSocket-Accept") from exc
|
||||
except MultipleValuesError as exc:
|
||||
raise InvalidHeader("Sec-WebSocket-Accept", "multiple values") from exc
|
||||
|
||||
if s_w_accept != accept_key(self.key):
|
||||
raise InvalidHeaderValue("Sec-WebSocket-Accept", s_w_accept)
|
||||
|
||||
self.extensions = self.process_extensions(headers)
|
||||
|
||||
self.subprotocol = self.process_subprotocol(headers)
|
||||
|
||||
def process_extensions(self, headers: Headers) -> list[Extension]:
|
||||
"""
|
||||
Handle the Sec-WebSocket-Extensions HTTP response header.
|
||||
|
||||
Check that each extension is supported, as well as its parameters.
|
||||
|
||||
:rfc:`6455` leaves the rules up to the specification of each
|
||||
extension.
|
||||
|
||||
To provide this level of flexibility, for each extension accepted by
|
||||
the server, we check for a match with each extension available in the
|
||||
client configuration. If no match is found, an exception is raised.
|
||||
|
||||
If several variants of the same extension are accepted by the server,
|
||||
it may be configured several times, which won't make sense in general.
|
||||
Extensions must implement their own requirements. For this purpose,
|
||||
the list of previously accepted extensions is provided.
|
||||
|
||||
Other requirements, for example related to mandatory extensions or the
|
||||
order of extensions, may be implemented by overriding this method.
|
||||
|
||||
Args:
|
||||
headers: WebSocket handshake response headers.
|
||||
|
||||
Returns:
|
||||
List of accepted extensions.
|
||||
|
||||
Raises:
|
||||
InvalidHandshake: To abort the handshake.
|
||||
|
||||
"""
|
||||
accepted_extensions: list[Extension] = []
|
||||
|
||||
extensions = headers.get_all("Sec-WebSocket-Extensions")
|
||||
|
||||
if extensions:
|
||||
if self.available_extensions is None:
|
||||
raise NegotiationError("no extensions supported")
|
||||
|
||||
parsed_extensions: list[ExtensionHeader] = sum(
|
||||
[parse_extension(header_value) for header_value in extensions], []
|
||||
)
|
||||
|
||||
for name, response_params in parsed_extensions:
|
||||
for extension_factory in self.available_extensions:
|
||||
# Skip non-matching extensions based on their name.
|
||||
if extension_factory.name != name:
|
||||
continue
|
||||
|
||||
# Skip non-matching extensions based on their params.
|
||||
try:
|
||||
extension = extension_factory.process_response_params(
|
||||
response_params, accepted_extensions
|
||||
)
|
||||
except NegotiationError:
|
||||
continue
|
||||
|
||||
# Add matching extension to the final list.
|
||||
accepted_extensions.append(extension)
|
||||
|
||||
# Break out of the loop once we have a match.
|
||||
break
|
||||
|
||||
# If we didn't break from the loop, no extension in our list
|
||||
# matched what the server sent. Fail the connection.
|
||||
else:
|
||||
raise NegotiationError(
|
||||
f"Unsupported extension: "
|
||||
f"name = {name}, params = {response_params}"
|
||||
)
|
||||
|
||||
return accepted_extensions
|
||||
|
||||
def process_subprotocol(self, headers: Headers) -> Subprotocol | None:
|
||||
"""
|
||||
Handle the Sec-WebSocket-Protocol HTTP response header.
|
||||
|
||||
If provided, check that it contains exactly one supported subprotocol.
|
||||
|
||||
Args:
|
||||
headers: WebSocket handshake response headers.
|
||||
|
||||
Returns:
|
||||
Subprotocol, if one was selected.
|
||||
|
||||
"""
|
||||
subprotocol: Subprotocol | None = None
|
||||
|
||||
subprotocols = headers.get_all("Sec-WebSocket-Protocol")
|
||||
|
||||
if subprotocols:
|
||||
if self.available_subprotocols is None:
|
||||
raise NegotiationError("no subprotocols supported")
|
||||
|
||||
parsed_subprotocols: Sequence[Subprotocol] = sum(
|
||||
[parse_subprotocol(header_value) for header_value in subprotocols], []
|
||||
)
|
||||
|
||||
if len(parsed_subprotocols) > 1:
|
||||
raise InvalidHeader(
|
||||
"Sec-WebSocket-Protocol",
|
||||
f"multiple values: {', '.join(parsed_subprotocols)}",
|
||||
)
|
||||
|
||||
subprotocol = parsed_subprotocols[0]
|
||||
|
||||
if subprotocol not in self.available_subprotocols:
|
||||
raise NegotiationError(f"unsupported subprotocol: {subprotocol}")
|
||||
|
||||
return subprotocol
|
||||
|
||||
def send_request(self, request: Request) -> None:
|
||||
"""
|
||||
Send a handshake request to the server.
|
||||
|
||||
Args:
|
||||
request: WebSocket handshake request event.
|
||||
|
||||
"""
|
||||
if self.debug:
|
||||
self.logger.debug("> GET %s HTTP/1.1", request.path)
|
||||
for key, value in request.headers.raw_items():
|
||||
self.logger.debug("> %s: %s", key, value)
|
||||
|
||||
self.writes.append(request.serialize())
|
||||
|
||||
def parse(self) -> Generator[None, None, None]:
|
||||
if self.state is CONNECTING:
|
||||
try:
|
||||
response = yield from Response.parse(
|
||||
self.reader.read_line,
|
||||
self.reader.read_exact,
|
||||
self.reader.read_to_eof,
|
||||
)
|
||||
except Exception as exc:
|
||||
self.handshake_exc = exc
|
||||
self.send_eof()
|
||||
self.parser = self.discard()
|
||||
next(self.parser) # start coroutine
|
||||
yield
|
||||
|
||||
if self.debug:
|
||||
code, phrase = response.status_code, response.reason_phrase
|
||||
self.logger.debug("< HTTP/1.1 %d %s", code, phrase)
|
||||
for key, value in response.headers.raw_items():
|
||||
self.logger.debug("< %s: %s", key, value)
|
||||
if response.body is not None:
|
||||
self.logger.debug("< [body] (%d bytes)", len(response.body))
|
||||
|
||||
try:
|
||||
self.process_response(response)
|
||||
except InvalidHandshake as exc:
|
||||
response._exception = exc
|
||||
self.events.append(response)
|
||||
self.handshake_exc = exc
|
||||
self.send_eof()
|
||||
self.parser = self.discard()
|
||||
next(self.parser) # start coroutine
|
||||
yield
|
||||
|
||||
assert self.state is CONNECTING
|
||||
self.state = OPEN
|
||||
self.events.append(response)
|
||||
|
||||
yield from super().parse()
|
||||
|
||||
|
||||
class ClientConnection(ClientProtocol):
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
warnings.warn( # deprecated in 11.0 - 2023-04-02
|
||||
"ClientConnection was renamed to ClientProtocol",
|
||||
DeprecationWarning,
|
||||
)
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
|
||||
BACKOFF_INITIAL_DELAY = float(os.environ.get("WEBSOCKETS_BACKOFF_INITIAL_DELAY", "5"))
|
||||
BACKOFF_MIN_DELAY = float(os.environ.get("WEBSOCKETS_BACKOFF_MIN_DELAY", "3.1"))
|
||||
BACKOFF_MAX_DELAY = float(os.environ.get("WEBSOCKETS_BACKOFF_MAX_DELAY", "90.0"))
|
||||
BACKOFF_FACTOR = float(os.environ.get("WEBSOCKETS_BACKOFF_FACTOR", "1.618"))
|
||||
|
||||
|
||||
def backoff(
|
||||
initial_delay: float = BACKOFF_INITIAL_DELAY,
|
||||
min_delay: float = BACKOFF_MIN_DELAY,
|
||||
max_delay: float = BACKOFF_MAX_DELAY,
|
||||
factor: float = BACKOFF_FACTOR,
|
||||
) -> Generator[float, None, None]:
|
||||
"""
|
||||
Generate a series of backoff delays between reconnection attempts.
|
||||
|
||||
Yields:
|
||||
How many seconds to wait before retrying to connect.
|
||||
|
||||
"""
|
||||
# Add a random initial delay between 0 and 5 seconds.
|
||||
# See 7.2.3. Recovering from Abnormal Closure in RFC 6455.
|
||||
yield random.random() * initial_delay
|
||||
delay = min_delay
|
||||
while delay < max_delay:
|
||||
yield delay
|
||||
delay *= factor
|
||||
while True:
|
||||
yield max_delay
|
@ -1,12 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import warnings
|
||||
|
||||
from .protocol import SEND_EOF, Protocol as Connection, Side, State # noqa: F401
|
||||
|
||||
|
||||
warnings.warn( # deprecated in 11.0 - 2023-04-02
|
||||
"websockets.connection was renamed to websockets.protocol "
|
||||
"and Connection was renamed to Protocol",
|
||||
DeprecationWarning,
|
||||
)
|
@ -1,192 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import (
|
||||
Any,
|
||||
Iterable,
|
||||
Iterator,
|
||||
Mapping,
|
||||
MutableMapping,
|
||||
Protocol,
|
||||
Tuple,
|
||||
Union,
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["Headers", "HeadersLike", "MultipleValuesError"]
|
||||
|
||||
|
||||
class MultipleValuesError(LookupError):
|
||||
"""
|
||||
Exception raised when :class:`Headers` has multiple values for a key.
|
||||
|
||||
"""
|
||||
|
||||
def __str__(self) -> str:
|
||||
# Implement the same logic as KeyError_str in Objects/exceptions.c.
|
||||
if len(self.args) == 1:
|
||||
return repr(self.args[0])
|
||||
return super().__str__()
|
||||
|
||||
|
||||
class Headers(MutableMapping[str, str]):
|
||||
"""
|
||||
Efficient data structure for manipulating HTTP headers.
|
||||
|
||||
A :class:`list` of ``(name, values)`` is inefficient for lookups.
|
||||
|
||||
A :class:`dict` doesn't suffice because header names are case-insensitive
|
||||
and multiple occurrences of headers with the same name are possible.
|
||||
|
||||
:class:`Headers` stores HTTP headers in a hybrid data structure to provide
|
||||
efficient insertions and lookups while preserving the original data.
|
||||
|
||||
In order to account for multiple values with minimal hassle,
|
||||
:class:`Headers` follows this logic:
|
||||
|
||||
- When getting a header with ``headers[name]``:
|
||||
- if there's no value, :exc:`KeyError` is raised;
|
||||
- if there's exactly one value, it's returned;
|
||||
- if there's more than one value, :exc:`MultipleValuesError` is raised.
|
||||
|
||||
- When setting a header with ``headers[name] = value``, the value is
|
||||
appended to the list of values for that header.
|
||||
|
||||
- When deleting a header with ``del headers[name]``, all values for that
|
||||
header are removed (this is slow).
|
||||
|
||||
Other methods for manipulating headers are consistent with this logic.
|
||||
|
||||
As long as no header occurs multiple times, :class:`Headers` behaves like
|
||||
:class:`dict`, except keys are lower-cased to provide case-insensitivity.
|
||||
|
||||
Two methods support manipulating multiple values explicitly:
|
||||
|
||||
- :meth:`get_all` returns a list of all values for a header;
|
||||
- :meth:`raw_items` returns an iterator of ``(name, values)`` pairs.
|
||||
|
||||
"""
|
||||
|
||||
__slots__ = ["_dict", "_list"]
|
||||
|
||||
# Like dict, Headers accepts an optional "mapping or iterable" argument.
|
||||
def __init__(self, *args: HeadersLike, **kwargs: str) -> None:
|
||||
self._dict: dict[str, list[str]] = {}
|
||||
self._list: list[tuple[str, str]] = []
|
||||
self.update(*args, **kwargs)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return "".join(f"{key}: {value}\r\n" for key, value in self._list) + "\r\n"
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}({self._list!r})"
|
||||
|
||||
def copy(self) -> Headers:
|
||||
copy = self.__class__()
|
||||
copy._dict = self._dict.copy()
|
||||
copy._list = self._list.copy()
|
||||
return copy
|
||||
|
||||
def serialize(self) -> bytes:
|
||||
# Since headers only contain ASCII characters, we can keep this simple.
|
||||
return str(self).encode()
|
||||
|
||||
# Collection methods
|
||||
|
||||
def __contains__(self, key: object) -> bool:
|
||||
return isinstance(key, str) and key.lower() in self._dict
|
||||
|
||||
def __iter__(self) -> Iterator[str]:
|
||||
return iter(self._dict)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._dict)
|
||||
|
||||
# MutableMapping methods
|
||||
|
||||
def __getitem__(self, key: str) -> str:
|
||||
value = self._dict[key.lower()]
|
||||
if len(value) == 1:
|
||||
return value[0]
|
||||
else:
|
||||
raise MultipleValuesError(key)
|
||||
|
||||
def __setitem__(self, key: str, value: str) -> None:
|
||||
self._dict.setdefault(key.lower(), []).append(value)
|
||||
self._list.append((key, value))
|
||||
|
||||
def __delitem__(self, key: str) -> None:
|
||||
key_lower = key.lower()
|
||||
self._dict.__delitem__(key_lower)
|
||||
# This is inefficient. Fortunately deleting HTTP headers is uncommon.
|
||||
self._list = [(k, v) for k, v in self._list if k.lower() != key_lower]
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
if not isinstance(other, Headers):
|
||||
return NotImplemented
|
||||
return self._dict == other._dict
|
||||
|
||||
def clear(self) -> None:
|
||||
"""
|
||||
Remove all headers.
|
||||
|
||||
"""
|
||||
self._dict = {}
|
||||
self._list = []
|
||||
|
||||
def update(self, *args: HeadersLike, **kwargs: str) -> None:
|
||||
"""
|
||||
Update from a :class:`Headers` instance and/or keyword arguments.
|
||||
|
||||
"""
|
||||
args = tuple(
|
||||
arg.raw_items() if isinstance(arg, Headers) else arg for arg in args
|
||||
)
|
||||
super().update(*args, **kwargs)
|
||||
|
||||
# Methods for handling multiple values
|
||||
|
||||
def get_all(self, key: str) -> list[str]:
|
||||
"""
|
||||
Return the (possibly empty) list of all values for a header.
|
||||
|
||||
Args:
|
||||
key: Header name.
|
||||
|
||||
"""
|
||||
return self._dict.get(key.lower(), [])
|
||||
|
||||
def raw_items(self) -> Iterator[tuple[str, str]]:
|
||||
"""
|
||||
Return an iterator of all values as ``(name, value)`` pairs.
|
||||
|
||||
"""
|
||||
return iter(self._list)
|
||||
|
||||
|
||||
# copy of _typeshed.SupportsKeysAndGetItem.
|
||||
class SupportsKeysAndGetItem(Protocol): # pragma: no cover
|
||||
"""
|
||||
Dict-like types with ``keys() -> str`` and ``__getitem__(key: str) -> str`` methods.
|
||||
|
||||
"""
|
||||
|
||||
def keys(self) -> Iterable[str]: ...
|
||||
|
||||
def __getitem__(self, key: str) -> str: ...
|
||||
|
||||
|
||||
# Change to Headers | Mapping[str, str] | ... when dropping Python < 3.10.
|
||||
HeadersLike = Union[
|
||||
Headers,
|
||||
Mapping[str, str],
|
||||
# Change to tuple[str, str] when dropping Python < 3.9.
|
||||
Iterable[Tuple[str, str]],
|
||||
SupportsKeysAndGetItem,
|
||||
]
|
||||
"""
|
||||
Types accepted where :class:`Headers` is expected.
|
||||
|
||||
In addition to :class:`Headers` itself, this includes dict-like types where both
|
||||
keys and values are :class:`str`.
|
||||
|
||||
"""
|
@ -1,392 +0,0 @@
|
||||
"""
|
||||
:mod:`websockets.exceptions` defines the following hierarchy of exceptions.
|
||||
|
||||
* :exc:`WebSocketException`
|
||||
* :exc:`ConnectionClosed`
|
||||
* :exc:`ConnectionClosedOK`
|
||||
* :exc:`ConnectionClosedError`
|
||||
* :exc:`InvalidURI`
|
||||
* :exc:`InvalidHandshake`
|
||||
* :exc:`SecurityError`
|
||||
* :exc:`InvalidMessage` (legacy)
|
||||
* :exc:`InvalidStatus`
|
||||
* :exc:`InvalidStatusCode` (legacy)
|
||||
* :exc:`InvalidHeader`
|
||||
* :exc:`InvalidHeaderFormat`
|
||||
* :exc:`InvalidHeaderValue`
|
||||
* :exc:`InvalidOrigin`
|
||||
* :exc:`InvalidUpgrade`
|
||||
* :exc:`NegotiationError`
|
||||
* :exc:`DuplicateParameter`
|
||||
* :exc:`InvalidParameterName`
|
||||
* :exc:`InvalidParameterValue`
|
||||
* :exc:`AbortHandshake` (legacy)
|
||||
* :exc:`RedirectHandshake` (legacy)
|
||||
* :exc:`ProtocolError` (Sans-I/O)
|
||||
* :exc:`PayloadTooBig` (Sans-I/O)
|
||||
* :exc:`InvalidState` (Sans-I/O)
|
||||
* :exc:`ConcurrencyError`
|
||||
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
import warnings
|
||||
|
||||
from .imports import lazy_import
|
||||
|
||||
|
||||
__all__ = [
|
||||
"WebSocketException",
|
||||
"ConnectionClosed",
|
||||
"ConnectionClosedOK",
|
||||
"ConnectionClosedError",
|
||||
"InvalidURI",
|
||||
"InvalidHandshake",
|
||||
"SecurityError",
|
||||
"InvalidMessage",
|
||||
"InvalidStatus",
|
||||
"InvalidStatusCode",
|
||||
"InvalidHeader",
|
||||
"InvalidHeaderFormat",
|
||||
"InvalidHeaderValue",
|
||||
"InvalidOrigin",
|
||||
"InvalidUpgrade",
|
||||
"NegotiationError",
|
||||
"DuplicateParameter",
|
||||
"InvalidParameterName",
|
||||
"InvalidParameterValue",
|
||||
"AbortHandshake",
|
||||
"RedirectHandshake",
|
||||
"ProtocolError",
|
||||
"WebSocketProtocolError",
|
||||
"PayloadTooBig",
|
||||
"InvalidState",
|
||||
"ConcurrencyError",
|
||||
]
|
||||
|
||||
|
||||
class WebSocketException(Exception):
|
||||
"""
|
||||
Base class for all exceptions defined by websockets.
|
||||
|
||||
"""
|
||||
|
||||
|
||||
class ConnectionClosed(WebSocketException):
|
||||
"""
|
||||
Raised when trying to interact with a closed connection.
|
||||
|
||||
Attributes:
|
||||
rcvd: If a close frame was received, its code and reason are available
|
||||
in ``rcvd.code`` and ``rcvd.reason``.
|
||||
sent: If a close frame was sent, its code and reason are available
|
||||
in ``sent.code`` and ``sent.reason``.
|
||||
rcvd_then_sent: If close frames were received and sent, this attribute
|
||||
tells in which order this happened, from the perspective of this
|
||||
side of the connection.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
rcvd: frames.Close | None,
|
||||
sent: frames.Close | None,
|
||||
rcvd_then_sent: bool | None = None,
|
||||
) -> None:
|
||||
self.rcvd = rcvd
|
||||
self.sent = sent
|
||||
self.rcvd_then_sent = rcvd_then_sent
|
||||
assert (self.rcvd_then_sent is None) == (self.rcvd is None or self.sent is None)
|
||||
|
||||
def __str__(self) -> str:
|
||||
if self.rcvd is None:
|
||||
if self.sent is None:
|
||||
return "no close frame received or sent"
|
||||
else:
|
||||
return f"sent {self.sent}; no close frame received"
|
||||
else:
|
||||
if self.sent is None:
|
||||
return f"received {self.rcvd}; no close frame sent"
|
||||
else:
|
||||
if self.rcvd_then_sent:
|
||||
return f"received {self.rcvd}; then sent {self.sent}"
|
||||
else:
|
||||
return f"sent {self.sent}; then received {self.rcvd}"
|
||||
|
||||
# code and reason attributes are provided for backwards-compatibility
|
||||
|
||||
@property
|
||||
def code(self) -> int:
|
||||
warnings.warn( # deprecated in 13.1
|
||||
"ConnectionClosed.code is deprecated; "
|
||||
"use Protocol.close_code or ConnectionClosed.rcvd.code",
|
||||
DeprecationWarning,
|
||||
)
|
||||
if self.rcvd is None:
|
||||
return frames.CloseCode.ABNORMAL_CLOSURE
|
||||
return self.rcvd.code
|
||||
|
||||
@property
|
||||
def reason(self) -> str:
|
||||
warnings.warn( # deprecated in 13.1
|
||||
"ConnectionClosed.reason is deprecated; "
|
||||
"use Protocol.close_reason or ConnectionClosed.rcvd.reason",
|
||||
DeprecationWarning,
|
||||
)
|
||||
if self.rcvd is None:
|
||||
return ""
|
||||
return self.rcvd.reason
|
||||
|
||||
|
||||
class ConnectionClosedOK(ConnectionClosed):
|
||||
"""
|
||||
Like :exc:`ConnectionClosed`, when the connection terminated properly.
|
||||
|
||||
A close code with code 1000 (OK) or 1001 (going away) or without a code was
|
||||
received and sent.
|
||||
|
||||
"""
|
||||
|
||||
|
||||
class ConnectionClosedError(ConnectionClosed):
|
||||
"""
|
||||
Like :exc:`ConnectionClosed`, when the connection terminated with an error.
|
||||
|
||||
A close frame with a code other than 1000 (OK) or 1001 (going away) was
|
||||
received or sent, or the closing handshake didn't complete properly.
|
||||
|
||||
"""
|
||||
|
||||
|
||||
class InvalidURI(WebSocketException):
|
||||
"""
|
||||
Raised when connecting to a URI that isn't a valid WebSocket URI.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, uri: str, msg: str) -> None:
|
||||
self.uri = uri
|
||||
self.msg = msg
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"{self.uri} isn't a valid URI: {self.msg}"
|
||||
|
||||
|
||||
class InvalidHandshake(WebSocketException):
|
||||
"""
|
||||
Base class for exceptions raised when the opening handshake fails.
|
||||
|
||||
"""
|
||||
|
||||
|
||||
class SecurityError(InvalidHandshake):
|
||||
"""
|
||||
Raised when a handshake request or response breaks a security rule.
|
||||
|
||||
Security limits can be configured with :doc:`environment variables
|
||||
<../reference/variables>`.
|
||||
|
||||
"""
|
||||
|
||||
|
||||
class InvalidStatus(InvalidHandshake):
|
||||
"""
|
||||
Raised when a handshake response rejects the WebSocket upgrade.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, response: http11.Response) -> None:
|
||||
self.response = response
|
||||
|
||||
def __str__(self) -> str:
|
||||
return (
|
||||
"server rejected WebSocket connection: "
|
||||
f"HTTP {self.response.status_code:d}"
|
||||
)
|
||||
|
||||
|
||||
class InvalidHeader(InvalidHandshake):
|
||||
"""
|
||||
Raised when an HTTP header doesn't have a valid format or value.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, name: str, value: str | None = None) -> None:
|
||||
self.name = name
|
||||
self.value = value
|
||||
|
||||
def __str__(self) -> str:
|
||||
if self.value is None:
|
||||
return f"missing {self.name} header"
|
||||
elif self.value == "":
|
||||
return f"empty {self.name} header"
|
||||
else:
|
||||
return f"invalid {self.name} header: {self.value}"
|
||||
|
||||
|
||||
class InvalidHeaderFormat(InvalidHeader):
|
||||
"""
|
||||
Raised when an HTTP header cannot be parsed.
|
||||
|
||||
The format of the header doesn't match the grammar for that header.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, name: str, error: str, header: str, pos: int) -> None:
|
||||
super().__init__(name, f"{error} at {pos} in {header}")
|
||||
|
||||
|
||||
class InvalidHeaderValue(InvalidHeader):
|
||||
"""
|
||||
Raised when an HTTP header has a wrong value.
|
||||
|
||||
The format of the header is correct but the value isn't acceptable.
|
||||
|
||||
"""
|
||||
|
||||
|
||||
class InvalidOrigin(InvalidHeader):
|
||||
"""
|
||||
Raised when the Origin header in a request isn't allowed.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, origin: str | None) -> None:
|
||||
super().__init__("Origin", origin)
|
||||
|
||||
|
||||
class InvalidUpgrade(InvalidHeader):
|
||||
"""
|
||||
Raised when the Upgrade or Connection header isn't correct.
|
||||
|
||||
"""
|
||||
|
||||
|
||||
class NegotiationError(InvalidHandshake):
|
||||
"""
|
||||
Raised when negotiating an extension or a subprotocol fails.
|
||||
|
||||
"""
|
||||
|
||||
|
||||
class DuplicateParameter(NegotiationError):
|
||||
"""
|
||||
Raised when a parameter name is repeated in an extension header.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, name: str) -> None:
|
||||
self.name = name
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"duplicate parameter: {self.name}"
|
||||
|
||||
|
||||
class InvalidParameterName(NegotiationError):
|
||||
"""
|
||||
Raised when a parameter name in an extension header is invalid.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, name: str) -> None:
|
||||
self.name = name
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"invalid parameter name: {self.name}"
|
||||
|
||||
|
||||
class InvalidParameterValue(NegotiationError):
|
||||
"""
|
||||
Raised when a parameter value in an extension header is invalid.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, name: str, value: str | None) -> None:
|
||||
self.name = name
|
||||
self.value = value
|
||||
|
||||
def __str__(self) -> str:
|
||||
if self.value is None:
|
||||
return f"missing value for parameter {self.name}"
|
||||
elif self.value == "":
|
||||
return f"empty value for parameter {self.name}"
|
||||
else:
|
||||
return f"invalid value for parameter {self.name}: {self.value}"
|
||||
|
||||
|
||||
class ProtocolError(WebSocketException):
|
||||
"""
|
||||
Raised when receiving or sending a frame that breaks the protocol.
|
||||
|
||||
The Sans-I/O implementation raises this exception when:
|
||||
|
||||
* receiving or sending a frame that contains invalid data;
|
||||
* receiving or sending an invalid sequence of frames.
|
||||
|
||||
"""
|
||||
|
||||
|
||||
class PayloadTooBig(WebSocketException):
|
||||
"""
|
||||
Raised when parsing a frame with a payload that exceeds the maximum size.
|
||||
|
||||
The Sans-I/O layer uses this exception internally. It doesn't bubble up to
|
||||
the I/O layer.
|
||||
|
||||
The :meth:`~websockets.extensions.Extension.decode` method of extensions
|
||||
must raise :exc:`PayloadTooBig` if decoding a frame would exceed the limit.
|
||||
|
||||
"""
|
||||
|
||||
|
||||
class InvalidState(WebSocketException, AssertionError):
|
||||
"""
|
||||
Raised when sending a frame is forbidden in the current state.
|
||||
|
||||
Specifically, the Sans-I/O layer raises this exception when:
|
||||
|
||||
* sending a data frame to a connection in a state other
|
||||
:attr:`~websockets.protocol.State.OPEN`;
|
||||
* sending a control frame to a connection in a state other than
|
||||
:attr:`~websockets.protocol.State.OPEN` or
|
||||
:attr:`~websockets.protocol.State.CLOSING`.
|
||||
|
||||
"""
|
||||
|
||||
|
||||
class ConcurrencyError(WebSocketException, RuntimeError):
|
||||
"""
|
||||
Raised when receiving or sending messages concurrently.
|
||||
|
||||
WebSocket is a connection-oriented protocol. Reads must be serialized; so
|
||||
must be writes. However, reading and writing concurrently is possible.
|
||||
|
||||
"""
|
||||
|
||||
|
||||
# When type checking, import non-deprecated aliases eagerly. Else, import on demand.
|
||||
if typing.TYPE_CHECKING:
|
||||
from .legacy.exceptions import (
|
||||
AbortHandshake,
|
||||
InvalidMessage,
|
||||
InvalidStatusCode,
|
||||
RedirectHandshake,
|
||||
)
|
||||
|
||||
WebSocketProtocolError = ProtocolError
|
||||
else:
|
||||
lazy_import(
|
||||
globals(),
|
||||
aliases={
|
||||
"AbortHandshake": ".legacy.exceptions",
|
||||
"InvalidMessage": ".legacy.exceptions",
|
||||
"InvalidStatusCode": ".legacy.exceptions",
|
||||
"RedirectHandshake": ".legacy.exceptions",
|
||||
"WebSocketProtocolError": ".legacy.exceptions",
|
||||
},
|
||||
)
|
||||
|
||||
# At the bottom to break import cycles created by type annotations.
|
||||
from . import frames, http11 # noqa: E402
|
@ -1,4 +0,0 @@
|
||||
from .base import *
|
||||
|
||||
|
||||
__all__ = ["Extension", "ClientExtensionFactory", "ServerExtensionFactory"]
|
@ -1,123 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Sequence
|
||||
|
||||
from ..frames import Frame
|
||||
from ..typing import ExtensionName, ExtensionParameter
|
||||
|
||||
|
||||
__all__ = ["Extension", "ClientExtensionFactory", "ServerExtensionFactory"]
|
||||
|
||||
|
||||
class Extension:
|
||||
"""
|
||||
Base class for extensions.
|
||||
|
||||
"""
|
||||
|
||||
name: ExtensionName
|
||||
"""Extension identifier."""
|
||||
|
||||
def decode(self, frame: Frame, *, max_size: int | None = None) -> Frame:
|
||||
"""
|
||||
Decode an incoming frame.
|
||||
|
||||
Args:
|
||||
frame: Incoming frame.
|
||||
max_size: Maximum payload size in bytes.
|
||||
|
||||
Returns:
|
||||
Decoded frame.
|
||||
|
||||
Raises:
|
||||
PayloadTooBig: If decoding the payload exceeds ``max_size``.
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def encode(self, frame: Frame) -> Frame:
|
||||
"""
|
||||
Encode an outgoing frame.
|
||||
|
||||
Args:
|
||||
frame: Outgoing frame.
|
||||
|
||||
Returns:
|
||||
Encoded frame.
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class ClientExtensionFactory:
|
||||
"""
|
||||
Base class for client-side extension factories.
|
||||
|
||||
"""
|
||||
|
||||
name: ExtensionName
|
||||
"""Extension identifier."""
|
||||
|
||||
def get_request_params(self) -> list[ExtensionParameter]:
|
||||
"""
|
||||
Build parameters to send to the server for this extension.
|
||||
|
||||
Returns:
|
||||
Parameters to send to the server.
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def process_response_params(
|
||||
self,
|
||||
params: Sequence[ExtensionParameter],
|
||||
accepted_extensions: Sequence[Extension],
|
||||
) -> Extension:
|
||||
"""
|
||||
Process parameters received from the server.
|
||||
|
||||
Args:
|
||||
params: Parameters received from the server for this extension.
|
||||
accepted_extensions: List of previously accepted extensions.
|
||||
|
||||
Returns:
|
||||
An extension instance.
|
||||
|
||||
Raises:
|
||||
NegotiationError: If parameters aren't acceptable.
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class ServerExtensionFactory:
|
||||
"""
|
||||
Base class for server-side extension factories.
|
||||
|
||||
"""
|
||||
|
||||
name: ExtensionName
|
||||
"""Extension identifier."""
|
||||
|
||||
def process_request_params(
|
||||
self,
|
||||
params: Sequence[ExtensionParameter],
|
||||
accepted_extensions: Sequence[Extension],
|
||||
) -> tuple[list[ExtensionParameter], Extension]:
|
||||
"""
|
||||
Process parameters received from the client.
|
||||
|
||||
Args:
|
||||
params: Parameters received from the client for this extension.
|
||||
accepted_extensions: List of previously accepted extensions.
|
||||
|
||||
Returns:
|
||||
To accept the offer, parameters to send to the client for this
|
||||
extension and an extension instance.
|
||||
|
||||
Raises:
|
||||
NegotiationError: To reject the offer, if parameters received from
|
||||
the client aren't acceptable.
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
@ -1,670 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
import zlib
|
||||
from typing import Any, Sequence
|
||||
|
||||
from .. import frames
|
||||
from ..exceptions import (
|
||||
DuplicateParameter,
|
||||
InvalidParameterName,
|
||||
InvalidParameterValue,
|
||||
NegotiationError,
|
||||
PayloadTooBig,
|
||||
ProtocolError,
|
||||
)
|
||||
from ..typing import ExtensionName, ExtensionParameter
|
||||
from .base import ClientExtensionFactory, Extension, ServerExtensionFactory
|
||||
|
||||
|
||||
__all__ = [
|
||||
"PerMessageDeflate",
|
||||
"ClientPerMessageDeflateFactory",
|
||||
"enable_client_permessage_deflate",
|
||||
"ServerPerMessageDeflateFactory",
|
||||
"enable_server_permessage_deflate",
|
||||
]
|
||||
|
||||
_EMPTY_UNCOMPRESSED_BLOCK = b"\x00\x00\xff\xff"
|
||||
|
||||
_MAX_WINDOW_BITS_VALUES = [str(bits) for bits in range(8, 16)]
|
||||
|
||||
|
||||
class PerMessageDeflate(Extension):
|
||||
"""
|
||||
Per-Message Deflate extension.
|
||||
|
||||
"""
|
||||
|
||||
name = ExtensionName("permessage-deflate")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
remote_no_context_takeover: bool,
|
||||
local_no_context_takeover: bool,
|
||||
remote_max_window_bits: int,
|
||||
local_max_window_bits: int,
|
||||
compress_settings: dict[Any, Any] | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Configure the Per-Message Deflate extension.
|
||||
|
||||
"""
|
||||
if compress_settings is None:
|
||||
compress_settings = {}
|
||||
|
||||
assert remote_no_context_takeover in [False, True]
|
||||
assert local_no_context_takeover in [False, True]
|
||||
assert 8 <= remote_max_window_bits <= 15
|
||||
assert 8 <= local_max_window_bits <= 15
|
||||
assert "wbits" not in compress_settings
|
||||
|
||||
self.remote_no_context_takeover = remote_no_context_takeover
|
||||
self.local_no_context_takeover = local_no_context_takeover
|
||||
self.remote_max_window_bits = remote_max_window_bits
|
||||
self.local_max_window_bits = local_max_window_bits
|
||||
self.compress_settings = compress_settings
|
||||
|
||||
if not self.remote_no_context_takeover:
|
||||
self.decoder = zlib.decompressobj(wbits=-self.remote_max_window_bits)
|
||||
|
||||
if not self.local_no_context_takeover:
|
||||
self.encoder = zlib.compressobj(
|
||||
wbits=-self.local_max_window_bits,
|
||||
**self.compress_settings,
|
||||
)
|
||||
|
||||
# To handle continuation frames properly, we must keep track of
|
||||
# whether that initial frame was encoded.
|
||||
self.decode_cont_data = False
|
||||
# There's no need for self.encode_cont_data because we always encode
|
||||
# outgoing frames, so it would always be True.
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"PerMessageDeflate("
|
||||
f"remote_no_context_takeover={self.remote_no_context_takeover}, "
|
||||
f"local_no_context_takeover={self.local_no_context_takeover}, "
|
||||
f"remote_max_window_bits={self.remote_max_window_bits}, "
|
||||
f"local_max_window_bits={self.local_max_window_bits})"
|
||||
)
|
||||
|
||||
def decode(
|
||||
self,
|
||||
frame: frames.Frame,
|
||||
*,
|
||||
max_size: int | None = None,
|
||||
) -> frames.Frame:
|
||||
"""
|
||||
Decode an incoming frame.
|
||||
|
||||
"""
|
||||
# Skip control frames.
|
||||
if frame.opcode in frames.CTRL_OPCODES:
|
||||
return frame
|
||||
|
||||
# Handle continuation data frames:
|
||||
# - skip if the message isn't encoded
|
||||
# - reset "decode continuation data" flag if it's a final frame
|
||||
if frame.opcode is frames.OP_CONT:
|
||||
if not self.decode_cont_data:
|
||||
return frame
|
||||
if frame.fin:
|
||||
self.decode_cont_data = False
|
||||
|
||||
# Handle text and binary data frames:
|
||||
# - skip if the message isn't encoded
|
||||
# - unset the rsv1 flag on the first frame of a compressed message
|
||||
# - set "decode continuation data" flag if it's a non-final frame
|
||||
else:
|
||||
if not frame.rsv1:
|
||||
return frame
|
||||
frame = dataclasses.replace(frame, rsv1=False)
|
||||
if not frame.fin:
|
||||
self.decode_cont_data = True
|
||||
|
||||
# Re-initialize per-message decoder.
|
||||
if self.remote_no_context_takeover:
|
||||
self.decoder = zlib.decompressobj(wbits=-self.remote_max_window_bits)
|
||||
|
||||
# Uncompress data. Protect against zip bombs by preventing zlib from
|
||||
# decompressing more than max_length bytes (except when the limit is
|
||||
# disabled with max_size = None).
|
||||
data = frame.data
|
||||
if frame.fin:
|
||||
data += _EMPTY_UNCOMPRESSED_BLOCK
|
||||
max_length = 0 if max_size is None else max_size
|
||||
try:
|
||||
data = self.decoder.decompress(data, max_length)
|
||||
except zlib.error as exc:
|
||||
raise ProtocolError("decompression failed") from exc
|
||||
if self.decoder.unconsumed_tail:
|
||||
raise PayloadTooBig(f"over size limit (? > {max_size} bytes)")
|
||||
|
||||
# Allow garbage collection of the decoder if it won't be reused.
|
||||
if frame.fin and self.remote_no_context_takeover:
|
||||
del self.decoder
|
||||
|
||||
return dataclasses.replace(frame, data=data)
|
||||
|
||||
def encode(self, frame: frames.Frame) -> frames.Frame:
|
||||
"""
|
||||
Encode an outgoing frame.
|
||||
|
||||
"""
|
||||
# Skip control frames.
|
||||
if frame.opcode in frames.CTRL_OPCODES:
|
||||
return frame
|
||||
|
||||
# Since we always encode messages, there's no "encode continuation
|
||||
# data" flag similar to "decode continuation data" at this time.
|
||||
|
||||
if frame.opcode is not frames.OP_CONT:
|
||||
# Set the rsv1 flag on the first frame of a compressed message.
|
||||
frame = dataclasses.replace(frame, rsv1=True)
|
||||
# Re-initialize per-message decoder.
|
||||
if self.local_no_context_takeover:
|
||||
self.encoder = zlib.compressobj(
|
||||
wbits=-self.local_max_window_bits,
|
||||
**self.compress_settings,
|
||||
)
|
||||
|
||||
# Compress data.
|
||||
data = self.encoder.compress(frame.data) + self.encoder.flush(zlib.Z_SYNC_FLUSH)
|
||||
if frame.fin and data.endswith(_EMPTY_UNCOMPRESSED_BLOCK):
|
||||
data = data[:-4]
|
||||
|
||||
# Allow garbage collection of the encoder if it won't be reused.
|
||||
if frame.fin and self.local_no_context_takeover:
|
||||
del self.encoder
|
||||
|
||||
return dataclasses.replace(frame, data=data)
|
||||
|
||||
|
||||
def _build_parameters(
|
||||
server_no_context_takeover: bool,
|
||||
client_no_context_takeover: bool,
|
||||
server_max_window_bits: int | None,
|
||||
client_max_window_bits: int | bool | None,
|
||||
) -> list[ExtensionParameter]:
|
||||
"""
|
||||
Build a list of ``(name, value)`` pairs for some compression parameters.
|
||||
|
||||
"""
|
||||
params: list[ExtensionParameter] = []
|
||||
if server_no_context_takeover:
|
||||
params.append(("server_no_context_takeover", None))
|
||||
if client_no_context_takeover:
|
||||
params.append(("client_no_context_takeover", None))
|
||||
if server_max_window_bits:
|
||||
params.append(("server_max_window_bits", str(server_max_window_bits)))
|
||||
if client_max_window_bits is True: # only in handshake requests
|
||||
params.append(("client_max_window_bits", None))
|
||||
elif client_max_window_bits:
|
||||
params.append(("client_max_window_bits", str(client_max_window_bits)))
|
||||
return params
|
||||
|
||||
|
||||
def _extract_parameters(
|
||||
params: Sequence[ExtensionParameter], *, is_server: bool
|
||||
) -> tuple[bool, bool, int | None, int | bool | None]:
|
||||
"""
|
||||
Extract compression parameters from a list of ``(name, value)`` pairs.
|
||||
|
||||
If ``is_server`` is :obj:`True`, ``client_max_window_bits`` may be
|
||||
provided without a value. This is only allowed in handshake requests.
|
||||
|
||||
"""
|
||||
server_no_context_takeover: bool = False
|
||||
client_no_context_takeover: bool = False
|
||||
server_max_window_bits: int | None = None
|
||||
client_max_window_bits: int | bool | None = None
|
||||
|
||||
for name, value in params:
|
||||
if name == "server_no_context_takeover":
|
||||
if server_no_context_takeover:
|
||||
raise DuplicateParameter(name)
|
||||
if value is None:
|
||||
server_no_context_takeover = True
|
||||
else:
|
||||
raise InvalidParameterValue(name, value)
|
||||
|
||||
elif name == "client_no_context_takeover":
|
||||
if client_no_context_takeover:
|
||||
raise DuplicateParameter(name)
|
||||
if value is None:
|
||||
client_no_context_takeover = True
|
||||
else:
|
||||
raise InvalidParameterValue(name, value)
|
||||
|
||||
elif name == "server_max_window_bits":
|
||||
if server_max_window_bits is not None:
|
||||
raise DuplicateParameter(name)
|
||||
if value in _MAX_WINDOW_BITS_VALUES:
|
||||
server_max_window_bits = int(value)
|
||||
else:
|
||||
raise InvalidParameterValue(name, value)
|
||||
|
||||
elif name == "client_max_window_bits":
|
||||
if client_max_window_bits is not None:
|
||||
raise DuplicateParameter(name)
|
||||
if is_server and value is None: # only in handshake requests
|
||||
client_max_window_bits = True
|
||||
elif value in _MAX_WINDOW_BITS_VALUES:
|
||||
client_max_window_bits = int(value)
|
||||
else:
|
||||
raise InvalidParameterValue(name, value)
|
||||
|
||||
else:
|
||||
raise InvalidParameterName(name)
|
||||
|
||||
return (
|
||||
server_no_context_takeover,
|
||||
client_no_context_takeover,
|
||||
server_max_window_bits,
|
||||
client_max_window_bits,
|
||||
)
|
||||
|
||||
|
||||
class ClientPerMessageDeflateFactory(ClientExtensionFactory):
|
||||
"""
|
||||
Client-side extension factory for the Per-Message Deflate extension.
|
||||
|
||||
Parameters behave as described in `section 7.1 of RFC 7692`_.
|
||||
|
||||
.. _section 7.1 of RFC 7692: https://datatracker.ietf.org/doc/html/rfc7692#section-7.1
|
||||
|
||||
Set them to :obj:`True` to include them in the negotiation offer without a
|
||||
value or to an integer value to include them with this value.
|
||||
|
||||
Args:
|
||||
server_no_context_takeover: Prevent server from using context takeover.
|
||||
client_no_context_takeover: Prevent client from using context takeover.
|
||||
server_max_window_bits: Maximum size of the server's LZ77 sliding window
|
||||
in bits, between 8 and 15.
|
||||
client_max_window_bits: Maximum size of the client's LZ77 sliding window
|
||||
in bits, between 8 and 15, or :obj:`True` to indicate support without
|
||||
setting a limit.
|
||||
compress_settings: Additional keyword arguments for :func:`zlib.compressobj`,
|
||||
excluding ``wbits``.
|
||||
|
||||
"""
|
||||
|
||||
name = ExtensionName("permessage-deflate")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
server_no_context_takeover: bool = False,
|
||||
client_no_context_takeover: bool = False,
|
||||
server_max_window_bits: int | None = None,
|
||||
client_max_window_bits: int | bool | None = True,
|
||||
compress_settings: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Configure the Per-Message Deflate extension factory.
|
||||
|
||||
"""
|
||||
if not (server_max_window_bits is None or 8 <= server_max_window_bits <= 15):
|
||||
raise ValueError("server_max_window_bits must be between 8 and 15")
|
||||
if not (
|
||||
client_max_window_bits is None
|
||||
or client_max_window_bits is True
|
||||
or 8 <= client_max_window_bits <= 15
|
||||
):
|
||||
raise ValueError("client_max_window_bits must be between 8 and 15")
|
||||
if compress_settings is not None and "wbits" in compress_settings:
|
||||
raise ValueError(
|
||||
"compress_settings must not include wbits, "
|
||||
"set client_max_window_bits instead"
|
||||
)
|
||||
|
||||
self.server_no_context_takeover = server_no_context_takeover
|
||||
self.client_no_context_takeover = client_no_context_takeover
|
||||
self.server_max_window_bits = server_max_window_bits
|
||||
self.client_max_window_bits = client_max_window_bits
|
||||
self.compress_settings = compress_settings
|
||||
|
||||
def get_request_params(self) -> list[ExtensionParameter]:
|
||||
"""
|
||||
Build request parameters.
|
||||
|
||||
"""
|
||||
return _build_parameters(
|
||||
self.server_no_context_takeover,
|
||||
self.client_no_context_takeover,
|
||||
self.server_max_window_bits,
|
||||
self.client_max_window_bits,
|
||||
)
|
||||
|
||||
def process_response_params(
|
||||
self,
|
||||
params: Sequence[ExtensionParameter],
|
||||
accepted_extensions: Sequence[Extension],
|
||||
) -> PerMessageDeflate:
|
||||
"""
|
||||
Process response parameters.
|
||||
|
||||
Return an extension instance.
|
||||
|
||||
"""
|
||||
if any(other.name == self.name for other in accepted_extensions):
|
||||
raise NegotiationError(f"received duplicate {self.name}")
|
||||
|
||||
# Request parameters are available in instance variables.
|
||||
|
||||
# Load response parameters in local variables.
|
||||
(
|
||||
server_no_context_takeover,
|
||||
client_no_context_takeover,
|
||||
server_max_window_bits,
|
||||
client_max_window_bits,
|
||||
) = _extract_parameters(params, is_server=False)
|
||||
|
||||
# After comparing the request and the response, the final
|
||||
# configuration must be available in the local variables.
|
||||
|
||||
# server_no_context_takeover
|
||||
#
|
||||
# Req. Resp. Result
|
||||
# ------ ------ --------------------------------------------------
|
||||
# False False False
|
||||
# False True True
|
||||
# True False Error!
|
||||
# True True True
|
||||
|
||||
if self.server_no_context_takeover:
|
||||
if not server_no_context_takeover:
|
||||
raise NegotiationError("expected server_no_context_takeover")
|
||||
|
||||
# client_no_context_takeover
|
||||
#
|
||||
# Req. Resp. Result
|
||||
# ------ ------ --------------------------------------------------
|
||||
# False False False
|
||||
# False True True
|
||||
# True False True - must change value
|
||||
# True True True
|
||||
|
||||
if self.client_no_context_takeover:
|
||||
if not client_no_context_takeover:
|
||||
client_no_context_takeover = True
|
||||
|
||||
# server_max_window_bits
|
||||
|
||||
# Req. Resp. Result
|
||||
# ------ ------ --------------------------------------------------
|
||||
# None None None
|
||||
# None 8≤M≤15 M
|
||||
# 8≤N≤15 None Error!
|
||||
# 8≤N≤15 8≤M≤N M
|
||||
# 8≤N≤15 N<M≤15 Error!
|
||||
|
||||
if self.server_max_window_bits is None:
|
||||
pass
|
||||
|
||||
else:
|
||||
if server_max_window_bits is None:
|
||||
raise NegotiationError("expected server_max_window_bits")
|
||||
elif server_max_window_bits > self.server_max_window_bits:
|
||||
raise NegotiationError("unsupported server_max_window_bits")
|
||||
|
||||
# client_max_window_bits
|
||||
|
||||
# Req. Resp. Result
|
||||
# ------ ------ --------------------------------------------------
|
||||
# None None None
|
||||
# None 8≤M≤15 Error!
|
||||
# True None None
|
||||
# True 8≤M≤15 M
|
||||
# 8≤N≤15 None N - must change value
|
||||
# 8≤N≤15 8≤M≤N M
|
||||
# 8≤N≤15 N<M≤15 Error!
|
||||
|
||||
if self.client_max_window_bits is None:
|
||||
if client_max_window_bits is not None:
|
||||
raise NegotiationError("unexpected client_max_window_bits")
|
||||
|
||||
elif self.client_max_window_bits is True:
|
||||
pass
|
||||
|
||||
else:
|
||||
if client_max_window_bits is None:
|
||||
client_max_window_bits = self.client_max_window_bits
|
||||
elif client_max_window_bits > self.client_max_window_bits:
|
||||
raise NegotiationError("unsupported client_max_window_bits")
|
||||
|
||||
return PerMessageDeflate(
|
||||
server_no_context_takeover, # remote_no_context_takeover
|
||||
client_no_context_takeover, # local_no_context_takeover
|
||||
server_max_window_bits or 15, # remote_max_window_bits
|
||||
client_max_window_bits or 15, # local_max_window_bits
|
||||
self.compress_settings,
|
||||
)
|
||||
|
||||
|
||||
def enable_client_permessage_deflate(
|
||||
extensions: Sequence[ClientExtensionFactory] | None,
|
||||
) -> Sequence[ClientExtensionFactory]:
|
||||
"""
|
||||
Enable Per-Message Deflate with default settings in client extensions.
|
||||
|
||||
If the extension is already present, perhaps with non-default settings,
|
||||
the configuration isn't changed.
|
||||
|
||||
"""
|
||||
if extensions is None:
|
||||
extensions = []
|
||||
if not any(
|
||||
extension_factory.name == ClientPerMessageDeflateFactory.name
|
||||
for extension_factory in extensions
|
||||
):
|
||||
extensions = list(extensions) + [
|
||||
ClientPerMessageDeflateFactory(
|
||||
compress_settings={"memLevel": 5},
|
||||
)
|
||||
]
|
||||
return extensions
|
||||
|
||||
|
||||
class ServerPerMessageDeflateFactory(ServerExtensionFactory):
|
||||
"""
|
||||
Server-side extension factory for the Per-Message Deflate extension.
|
||||
|
||||
Parameters behave as described in `section 7.1 of RFC 7692`_.
|
||||
|
||||
.. _section 7.1 of RFC 7692: https://datatracker.ietf.org/doc/html/rfc7692#section-7.1
|
||||
|
||||
Set them to :obj:`True` to include them in the negotiation offer without a
|
||||
value or to an integer value to include them with this value.
|
||||
|
||||
Args:
|
||||
server_no_context_takeover: Prevent server from using context takeover.
|
||||
client_no_context_takeover: Prevent client from using context takeover.
|
||||
server_max_window_bits: Maximum size of the server's LZ77 sliding window
|
||||
in bits, between 8 and 15.
|
||||
client_max_window_bits: Maximum size of the client's LZ77 sliding window
|
||||
in bits, between 8 and 15.
|
||||
compress_settings: Additional keyword arguments for :func:`zlib.compressobj`,
|
||||
excluding ``wbits``.
|
||||
require_client_max_window_bits: Do not enable compression at all if
|
||||
client doesn't advertise support for ``client_max_window_bits``;
|
||||
the default behavior is to enable compression without enforcing
|
||||
``client_max_window_bits``.
|
||||
|
||||
"""
|
||||
|
||||
name = ExtensionName("permessage-deflate")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
server_no_context_takeover: bool = False,
|
||||
client_no_context_takeover: bool = False,
|
||||
server_max_window_bits: int | None = None,
|
||||
client_max_window_bits: int | None = None,
|
||||
compress_settings: dict[str, Any] | None = None,
|
||||
require_client_max_window_bits: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Configure the Per-Message Deflate extension factory.
|
||||
|
||||
"""
|
||||
if not (server_max_window_bits is None or 8 <= server_max_window_bits <= 15):
|
||||
raise ValueError("server_max_window_bits must be between 8 and 15")
|
||||
if not (client_max_window_bits is None or 8 <= client_max_window_bits <= 15):
|
||||
raise ValueError("client_max_window_bits must be between 8 and 15")
|
||||
if compress_settings is not None and "wbits" in compress_settings:
|
||||
raise ValueError(
|
||||
"compress_settings must not include wbits, "
|
||||
"set server_max_window_bits instead"
|
||||
)
|
||||
if client_max_window_bits is None and require_client_max_window_bits:
|
||||
raise ValueError(
|
||||
"require_client_max_window_bits is enabled, "
|
||||
"but client_max_window_bits isn't configured"
|
||||
)
|
||||
|
||||
self.server_no_context_takeover = server_no_context_takeover
|
||||
self.client_no_context_takeover = client_no_context_takeover
|
||||
self.server_max_window_bits = server_max_window_bits
|
||||
self.client_max_window_bits = client_max_window_bits
|
||||
self.compress_settings = compress_settings
|
||||
self.require_client_max_window_bits = require_client_max_window_bits
|
||||
|
||||
def process_request_params(
|
||||
self,
|
||||
params: Sequence[ExtensionParameter],
|
||||
accepted_extensions: Sequence[Extension],
|
||||
) -> tuple[list[ExtensionParameter], PerMessageDeflate]:
|
||||
"""
|
||||
Process request parameters.
|
||||
|
||||
Return response params and an extension instance.
|
||||
|
||||
"""
|
||||
if any(other.name == self.name for other in accepted_extensions):
|
||||
raise NegotiationError(f"skipped duplicate {self.name}")
|
||||
|
||||
# Load request parameters in local variables.
|
||||
(
|
||||
server_no_context_takeover,
|
||||
client_no_context_takeover,
|
||||
server_max_window_bits,
|
||||
client_max_window_bits,
|
||||
) = _extract_parameters(params, is_server=True)
|
||||
|
||||
# Configuration parameters are available in instance variables.
|
||||
|
||||
# After comparing the request and the configuration, the response must
|
||||
# be available in the local variables.
|
||||
|
||||
# server_no_context_takeover
|
||||
#
|
||||
# Config Req. Resp.
|
||||
# ------ ------ --------------------------------------------------
|
||||
# False False False
|
||||
# False True True
|
||||
# True False True - must change value to True
|
||||
# True True True
|
||||
|
||||
if self.server_no_context_takeover:
|
||||
if not server_no_context_takeover:
|
||||
server_no_context_takeover = True
|
||||
|
||||
# client_no_context_takeover
|
||||
#
|
||||
# Config Req. Resp.
|
||||
# ------ ------ --------------------------------------------------
|
||||
# False False False
|
||||
# False True True (or False)
|
||||
# True False True - must change value to True
|
||||
# True True True (or False)
|
||||
|
||||
if self.client_no_context_takeover:
|
||||
if not client_no_context_takeover:
|
||||
client_no_context_takeover = True
|
||||
|
||||
# server_max_window_bits
|
||||
|
||||
# Config Req. Resp.
|
||||
# ------ ------ --------------------------------------------------
|
||||
# None None None
|
||||
# None 8≤M≤15 M
|
||||
# 8≤N≤15 None N - must change value
|
||||
# 8≤N≤15 8≤M≤N M
|
||||
# 8≤N≤15 N<M≤15 N - must change value
|
||||
|
||||
if self.server_max_window_bits is None:
|
||||
pass
|
||||
|
||||
else:
|
||||
if server_max_window_bits is None:
|
||||
server_max_window_bits = self.server_max_window_bits
|
||||
elif server_max_window_bits > self.server_max_window_bits:
|
||||
server_max_window_bits = self.server_max_window_bits
|
||||
|
||||
# client_max_window_bits
|
||||
|
||||
# Config Req. Resp.
|
||||
# ------ ------ --------------------------------------------------
|
||||
# None None None
|
||||
# None True None - must change value
|
||||
# None 8≤M≤15 M (or None)
|
||||
# 8≤N≤15 None None or Error!
|
||||
# 8≤N≤15 True N - must change value
|
||||
# 8≤N≤15 8≤M≤N M (or None)
|
||||
# 8≤N≤15 N<M≤15 N
|
||||
|
||||
if self.client_max_window_bits is None:
|
||||
if client_max_window_bits is True:
|
||||
client_max_window_bits = self.client_max_window_bits
|
||||
|
||||
else:
|
||||
if client_max_window_bits is None:
|
||||
if self.require_client_max_window_bits:
|
||||
raise NegotiationError("required client_max_window_bits")
|
||||
elif client_max_window_bits is True:
|
||||
client_max_window_bits = self.client_max_window_bits
|
||||
elif self.client_max_window_bits < client_max_window_bits:
|
||||
client_max_window_bits = self.client_max_window_bits
|
||||
|
||||
return (
|
||||
_build_parameters(
|
||||
server_no_context_takeover,
|
||||
client_no_context_takeover,
|
||||
server_max_window_bits,
|
||||
client_max_window_bits,
|
||||
),
|
||||
PerMessageDeflate(
|
||||
client_no_context_takeover, # remote_no_context_takeover
|
||||
server_no_context_takeover, # local_no_context_takeover
|
||||
client_max_window_bits or 15, # remote_max_window_bits
|
||||
server_max_window_bits or 15, # local_max_window_bits
|
||||
self.compress_settings,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def enable_server_permessage_deflate(
|
||||
extensions: Sequence[ServerExtensionFactory] | None,
|
||||
) -> Sequence[ServerExtensionFactory]:
|
||||
"""
|
||||
Enable Per-Message Deflate with default settings in server extensions.
|
||||
|
||||
If the extension is already present, perhaps with non-default settings,
|
||||
the configuration isn't changed.
|
||||
|
||||
"""
|
||||
if extensions is None:
|
||||
extensions = []
|
||||
if not any(
|
||||
ext_factory.name == ServerPerMessageDeflateFactory.name
|
||||
for ext_factory in extensions
|
||||
):
|
||||
extensions = list(extensions) + [
|
||||
ServerPerMessageDeflateFactory(
|
||||
server_max_window_bits=12,
|
||||
client_max_window_bits=12,
|
||||
compress_settings={"memLevel": 5},
|
||||
)
|
||||
]
|
||||
return extensions
|
@ -1,429 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
import enum
|
||||
import io
|
||||
import os
|
||||
import secrets
|
||||
import struct
|
||||
from typing import Callable, Generator, Sequence
|
||||
|
||||
from .exceptions import PayloadTooBig, ProtocolError
|
||||
|
||||
|
||||
try:
|
||||
from .speedups import apply_mask
|
||||
except ImportError:
|
||||
from .utils import apply_mask
|
||||
|
||||
|
||||
__all__ = [
|
||||
"Opcode",
|
||||
"OP_CONT",
|
||||
"OP_TEXT",
|
||||
"OP_BINARY",
|
||||
"OP_CLOSE",
|
||||
"OP_PING",
|
||||
"OP_PONG",
|
||||
"DATA_OPCODES",
|
||||
"CTRL_OPCODES",
|
||||
"Frame",
|
||||
"Close",
|
||||
]
|
||||
|
||||
|
||||
class Opcode(enum.IntEnum):
|
||||
"""Opcode values for WebSocket frames."""
|
||||
|
||||
CONT, TEXT, BINARY = 0x00, 0x01, 0x02
|
||||
CLOSE, PING, PONG = 0x08, 0x09, 0x0A
|
||||
|
||||
|
||||
OP_CONT = Opcode.CONT
|
||||
OP_TEXT = Opcode.TEXT
|
||||
OP_BINARY = Opcode.BINARY
|
||||
OP_CLOSE = Opcode.CLOSE
|
||||
OP_PING = Opcode.PING
|
||||
OP_PONG = Opcode.PONG
|
||||
|
||||
DATA_OPCODES = OP_CONT, OP_TEXT, OP_BINARY
|
||||
CTRL_OPCODES = OP_CLOSE, OP_PING, OP_PONG
|
||||
|
||||
|
||||
class CloseCode(enum.IntEnum):
|
||||
"""Close code values for WebSocket close frames."""
|
||||
|
||||
NORMAL_CLOSURE = 1000
|
||||
GOING_AWAY = 1001
|
||||
PROTOCOL_ERROR = 1002
|
||||
UNSUPPORTED_DATA = 1003
|
||||
# 1004 is reserved
|
||||
NO_STATUS_RCVD = 1005
|
||||
ABNORMAL_CLOSURE = 1006
|
||||
INVALID_DATA = 1007
|
||||
POLICY_VIOLATION = 1008
|
||||
MESSAGE_TOO_BIG = 1009
|
||||
MANDATORY_EXTENSION = 1010
|
||||
INTERNAL_ERROR = 1011
|
||||
SERVICE_RESTART = 1012
|
||||
TRY_AGAIN_LATER = 1013
|
||||
BAD_GATEWAY = 1014
|
||||
TLS_HANDSHAKE = 1015
|
||||
|
||||
|
||||
# See https://www.iana.org/assignments/websocket/websocket.xhtml
|
||||
CLOSE_CODE_EXPLANATIONS: dict[int, str] = {
|
||||
CloseCode.NORMAL_CLOSURE: "OK",
|
||||
CloseCode.GOING_AWAY: "going away",
|
||||
CloseCode.PROTOCOL_ERROR: "protocol error",
|
||||
CloseCode.UNSUPPORTED_DATA: "unsupported data",
|
||||
CloseCode.NO_STATUS_RCVD: "no status received [internal]",
|
||||
CloseCode.ABNORMAL_CLOSURE: "abnormal closure [internal]",
|
||||
CloseCode.INVALID_DATA: "invalid frame payload data",
|
||||
CloseCode.POLICY_VIOLATION: "policy violation",
|
||||
CloseCode.MESSAGE_TOO_BIG: "message too big",
|
||||
CloseCode.MANDATORY_EXTENSION: "mandatory extension",
|
||||
CloseCode.INTERNAL_ERROR: "internal error",
|
||||
CloseCode.SERVICE_RESTART: "service restart",
|
||||
CloseCode.TRY_AGAIN_LATER: "try again later",
|
||||
CloseCode.BAD_GATEWAY: "bad gateway",
|
||||
CloseCode.TLS_HANDSHAKE: "TLS handshake failure [internal]",
|
||||
}
|
||||
|
||||
|
||||
# Close code that are allowed in a close frame.
|
||||
# Using a set optimizes `code in EXTERNAL_CLOSE_CODES`.
|
||||
EXTERNAL_CLOSE_CODES = {
|
||||
CloseCode.NORMAL_CLOSURE,
|
||||
CloseCode.GOING_AWAY,
|
||||
CloseCode.PROTOCOL_ERROR,
|
||||
CloseCode.UNSUPPORTED_DATA,
|
||||
CloseCode.INVALID_DATA,
|
||||
CloseCode.POLICY_VIOLATION,
|
||||
CloseCode.MESSAGE_TOO_BIG,
|
||||
CloseCode.MANDATORY_EXTENSION,
|
||||
CloseCode.INTERNAL_ERROR,
|
||||
CloseCode.SERVICE_RESTART,
|
||||
CloseCode.TRY_AGAIN_LATER,
|
||||
CloseCode.BAD_GATEWAY,
|
||||
}
|
||||
|
||||
|
||||
OK_CLOSE_CODES = {
|
||||
CloseCode.NORMAL_CLOSURE,
|
||||
CloseCode.GOING_AWAY,
|
||||
CloseCode.NO_STATUS_RCVD,
|
||||
}
|
||||
|
||||
|
||||
BytesLike = bytes, bytearray, memoryview
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Frame:
|
||||
"""
|
||||
WebSocket frame.
|
||||
|
||||
Attributes:
|
||||
opcode: Opcode.
|
||||
data: Payload data.
|
||||
fin: FIN bit.
|
||||
rsv1: RSV1 bit.
|
||||
rsv2: RSV2 bit.
|
||||
rsv3: RSV3 bit.
|
||||
|
||||
Only these fields are needed. The MASK bit, payload length and masking-key
|
||||
are handled on the fly when parsing and serializing frames.
|
||||
|
||||
"""
|
||||
|
||||
opcode: Opcode
|
||||
data: bytes
|
||||
fin: bool = True
|
||||
rsv1: bool = False
|
||||
rsv2: bool = False
|
||||
rsv3: bool = False
|
||||
|
||||
# Configure if you want to see more in logs. Should be a multiple of 3.
|
||||
MAX_LOG_SIZE = int(os.environ.get("WEBSOCKETS_MAX_LOG_SIZE", "75"))
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""
|
||||
Return a human-readable representation of a frame.
|
||||
|
||||
"""
|
||||
coding = None
|
||||
length = f"{len(self.data)} byte{'' if len(self.data) == 1 else 's'}"
|
||||
non_final = "" if self.fin else "continued"
|
||||
|
||||
if self.opcode is OP_TEXT:
|
||||
# Decoding only the beginning and the end is needlessly hard.
|
||||
# Decode the entire payload then elide later if necessary.
|
||||
data = repr(self.data.decode())
|
||||
elif self.opcode is OP_BINARY:
|
||||
# We'll show at most the first 16 bytes and the last 8 bytes.
|
||||
# Encode just what we need, plus two dummy bytes to elide later.
|
||||
binary = self.data
|
||||
if len(binary) > self.MAX_LOG_SIZE // 3:
|
||||
cut = (self.MAX_LOG_SIZE // 3 - 1) // 3 # by default cut = 8
|
||||
binary = b"".join([binary[: 2 * cut], b"\x00\x00", binary[-cut:]])
|
||||
data = " ".join(f"{byte:02x}" for byte in binary)
|
||||
elif self.opcode is OP_CLOSE:
|
||||
data = str(Close.parse(self.data))
|
||||
elif self.data:
|
||||
# We don't know if a Continuation frame contains text or binary.
|
||||
# Ping and Pong frames could contain UTF-8.
|
||||
# Attempt to decode as UTF-8 and display it as text; fallback to
|
||||
# binary. If self.data is a memoryview, it has no decode() method,
|
||||
# which raises AttributeError.
|
||||
try:
|
||||
data = repr(self.data.decode())
|
||||
coding = "text"
|
||||
except (UnicodeDecodeError, AttributeError):
|
||||
binary = self.data
|
||||
if len(binary) > self.MAX_LOG_SIZE // 3:
|
||||
cut = (self.MAX_LOG_SIZE // 3 - 1) // 3 # by default cut = 8
|
||||
binary = b"".join([binary[: 2 * cut], b"\x00\x00", binary[-cut:]])
|
||||
data = " ".join(f"{byte:02x}" for byte in binary)
|
||||
coding = "binary"
|
||||
else:
|
||||
data = "''"
|
||||
|
||||
if len(data) > self.MAX_LOG_SIZE:
|
||||
cut = self.MAX_LOG_SIZE // 3 - 1 # by default cut = 24
|
||||
data = data[: 2 * cut] + "..." + data[-cut:]
|
||||
|
||||
metadata = ", ".join(filter(None, [coding, length, non_final]))
|
||||
|
||||
return f"{self.opcode.name} {data} [{metadata}]"
|
||||
|
||||
@classmethod
|
||||
def parse(
|
||||
cls,
|
||||
read_exact: Callable[[int], Generator[None, None, bytes]],
|
||||
*,
|
||||
mask: bool,
|
||||
max_size: int | None = None,
|
||||
extensions: Sequence[extensions.Extension] | None = None,
|
||||
) -> Generator[None, None, Frame]:
|
||||
"""
|
||||
Parse a WebSocket frame.
|
||||
|
||||
This is a generator-based coroutine.
|
||||
|
||||
Args:
|
||||
read_exact: Generator-based coroutine that reads the requested
|
||||
bytes or raises an exception if there isn't enough data.
|
||||
mask: Whether the frame should be masked i.e. whether the read
|
||||
happens on the server side.
|
||||
max_size: Maximum payload size in bytes.
|
||||
extensions: List of extensions, applied in reverse order.
|
||||
|
||||
Raises:
|
||||
EOFError: If the connection is closed without a full WebSocket frame.
|
||||
UnicodeDecodeError: If the frame contains invalid UTF-8.
|
||||
PayloadTooBig: If the frame's payload size exceeds ``max_size``.
|
||||
ProtocolError: If the frame contains incorrect values.
|
||||
|
||||
"""
|
||||
# Read the header.
|
||||
data = yield from read_exact(2)
|
||||
head1, head2 = struct.unpack("!BB", data)
|
||||
|
||||
# While not Pythonic, this is marginally faster than calling bool().
|
||||
fin = True if head1 & 0b10000000 else False
|
||||
rsv1 = True if head1 & 0b01000000 else False
|
||||
rsv2 = True if head1 & 0b00100000 else False
|
||||
rsv3 = True if head1 & 0b00010000 else False
|
||||
|
||||
try:
|
||||
opcode = Opcode(head1 & 0b00001111)
|
||||
except ValueError as exc:
|
||||
raise ProtocolError("invalid opcode") from exc
|
||||
|
||||
if (True if head2 & 0b10000000 else False) != mask:
|
||||
raise ProtocolError("incorrect masking")
|
||||
|
||||
length = head2 & 0b01111111
|
||||
if length == 126:
|
||||
data = yield from read_exact(2)
|
||||
(length,) = struct.unpack("!H", data)
|
||||
elif length == 127:
|
||||
data = yield from read_exact(8)
|
||||
(length,) = struct.unpack("!Q", data)
|
||||
if max_size is not None and length > max_size:
|
||||
raise PayloadTooBig(f"over size limit ({length} > {max_size} bytes)")
|
||||
if mask:
|
||||
mask_bytes = yield from read_exact(4)
|
||||
|
||||
# Read the data.
|
||||
data = yield from read_exact(length)
|
||||
if mask:
|
||||
data = apply_mask(data, mask_bytes)
|
||||
|
||||
frame = cls(opcode, data, fin, rsv1, rsv2, rsv3)
|
||||
|
||||
if extensions is None:
|
||||
extensions = []
|
||||
for extension in reversed(extensions):
|
||||
frame = extension.decode(frame, max_size=max_size)
|
||||
|
||||
frame.check()
|
||||
|
||||
return frame
|
||||
|
||||
def serialize(
|
||||
self,
|
||||
*,
|
||||
mask: bool,
|
||||
extensions: Sequence[extensions.Extension] | None = None,
|
||||
) -> bytes:
|
||||
"""
|
||||
Serialize a WebSocket frame.
|
||||
|
||||
Args:
|
||||
mask: Whether the frame should be masked i.e. whether the write
|
||||
happens on the client side.
|
||||
extensions: List of extensions, applied in order.
|
||||
|
||||
Raises:
|
||||
ProtocolError: If the frame contains incorrect values.
|
||||
|
||||
"""
|
||||
self.check()
|
||||
|
||||
if extensions is None:
|
||||
extensions = []
|
||||
for extension in extensions:
|
||||
self = extension.encode(self)
|
||||
|
||||
output = io.BytesIO()
|
||||
|
||||
# Prepare the header.
|
||||
head1 = (
|
||||
(0b10000000 if self.fin else 0)
|
||||
| (0b01000000 if self.rsv1 else 0)
|
||||
| (0b00100000 if self.rsv2 else 0)
|
||||
| (0b00010000 if self.rsv3 else 0)
|
||||
| self.opcode
|
||||
)
|
||||
|
||||
head2 = 0b10000000 if mask else 0
|
||||
|
||||
length = len(self.data)
|
||||
if length < 126:
|
||||
output.write(struct.pack("!BB", head1, head2 | length))
|
||||
elif length < 65536:
|
||||
output.write(struct.pack("!BBH", head1, head2 | 126, length))
|
||||
else:
|
||||
output.write(struct.pack("!BBQ", head1, head2 | 127, length))
|
||||
|
||||
if mask:
|
||||
mask_bytes = secrets.token_bytes(4)
|
||||
output.write(mask_bytes)
|
||||
|
||||
# Prepare the data.
|
||||
if mask:
|
||||
data = apply_mask(self.data, mask_bytes)
|
||||
else:
|
||||
data = self.data
|
||||
output.write(data)
|
||||
|
||||
return output.getvalue()
|
||||
|
||||
def check(self) -> None:
|
||||
"""
|
||||
Check that reserved bits and opcode have acceptable values.
|
||||
|
||||
Raises:
|
||||
ProtocolError: If a reserved bit or the opcode is invalid.
|
||||
|
||||
"""
|
||||
if self.rsv1 or self.rsv2 or self.rsv3:
|
||||
raise ProtocolError("reserved bits must be 0")
|
||||
|
||||
if self.opcode in CTRL_OPCODES:
|
||||
if len(self.data) > 125:
|
||||
raise ProtocolError("control frame too long")
|
||||
if not self.fin:
|
||||
raise ProtocolError("fragmented control frame")
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Close:
|
||||
"""
|
||||
Code and reason for WebSocket close frames.
|
||||
|
||||
Attributes:
|
||||
code: Close code.
|
||||
reason: Close reason.
|
||||
|
||||
"""
|
||||
|
||||
code: int
|
||||
reason: str
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""
|
||||
Return a human-readable representation of a close code and reason.
|
||||
|
||||
"""
|
||||
if 3000 <= self.code < 4000:
|
||||
explanation = "registered"
|
||||
elif 4000 <= self.code < 5000:
|
||||
explanation = "private use"
|
||||
else:
|
||||
explanation = CLOSE_CODE_EXPLANATIONS.get(self.code, "unknown")
|
||||
result = f"{self.code} ({explanation})"
|
||||
|
||||
if self.reason:
|
||||
result = f"{result} {self.reason}"
|
||||
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def parse(cls, data: bytes) -> Close:
|
||||
"""
|
||||
Parse the payload of a close frame.
|
||||
|
||||
Args:
|
||||
data: Payload of the close frame.
|
||||
|
||||
Raises:
|
||||
ProtocolError: If data is ill-formed.
|
||||
UnicodeDecodeError: If the reason isn't valid UTF-8.
|
||||
|
||||
"""
|
||||
if len(data) >= 2:
|
||||
(code,) = struct.unpack("!H", data[:2])
|
||||
reason = data[2:].decode()
|
||||
close = cls(code, reason)
|
||||
close.check()
|
||||
return close
|
||||
elif len(data) == 0:
|
||||
return cls(CloseCode.NO_STATUS_RCVD, "")
|
||||
else:
|
||||
raise ProtocolError("close frame too short")
|
||||
|
||||
def serialize(self) -> bytes:
|
||||
"""
|
||||
Serialize the payload of a close frame.
|
||||
|
||||
"""
|
||||
self.check()
|
||||
return struct.pack("!H", self.code) + self.reason.encode()
|
||||
|
||||
def check(self) -> None:
|
||||
"""
|
||||
Check that the close code has a valid value for a close frame.
|
||||
|
||||
Raises:
|
||||
ProtocolError: If the close code is invalid.
|
||||
|
||||
"""
|
||||
if not (self.code in EXTERNAL_CLOSE_CODES or 3000 <= self.code < 5000):
|
||||
raise ProtocolError("invalid status code")
|
||||
|
||||
|
||||
# At the bottom to break import cycles created by type annotations.
|
||||
from . import extensions # noqa: E402
|
@ -1,579 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import binascii
|
||||
import ipaddress
|
||||
import re
|
||||
from typing import Callable, Sequence, TypeVar, cast
|
||||
|
||||
from .exceptions import InvalidHeaderFormat, InvalidHeaderValue
|
||||
from .typing import (
|
||||
ConnectionOption,
|
||||
ExtensionHeader,
|
||||
ExtensionName,
|
||||
ExtensionParameter,
|
||||
Subprotocol,
|
||||
UpgradeProtocol,
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"build_host",
|
||||
"parse_connection",
|
||||
"parse_upgrade",
|
||||
"parse_extension",
|
||||
"build_extension",
|
||||
"parse_subprotocol",
|
||||
"build_subprotocol",
|
||||
"validate_subprotocols",
|
||||
"build_www_authenticate_basic",
|
||||
"parse_authorization_basic",
|
||||
"build_authorization_basic",
|
||||
]
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def build_host(host: str, port: int, secure: bool) -> str:
|
||||
"""
|
||||
Build a ``Host`` header.
|
||||
|
||||
"""
|
||||
# https://datatracker.ietf.org/doc/html/rfc3986#section-3.2.2
|
||||
# IPv6 addresses must be enclosed in brackets.
|
||||
try:
|
||||
address = ipaddress.ip_address(host)
|
||||
except ValueError:
|
||||
# host is a hostname
|
||||
pass
|
||||
else:
|
||||
# host is an IP address
|
||||
if address.version == 6:
|
||||
host = f"[{host}]"
|
||||
|
||||
if port != (443 if secure else 80):
|
||||
host = f"{host}:{port}"
|
||||
|
||||
return host
|
||||
|
||||
|
||||
# To avoid a dependency on a parsing library, we implement manually the ABNF
|
||||
# described in https://datatracker.ietf.org/doc/html/rfc6455#section-9.1 and
|
||||
# https://datatracker.ietf.org/doc/html/rfc7230#appendix-B.
|
||||
|
||||
|
||||
def peek_ahead(header: str, pos: int) -> str | None:
|
||||
"""
|
||||
Return the next character from ``header`` at the given position.
|
||||
|
||||
Return :obj:`None` at the end of ``header``.
|
||||
|
||||
We never need to peek more than one character ahead.
|
||||
|
||||
"""
|
||||
return None if pos == len(header) else header[pos]
|
||||
|
||||
|
||||
_OWS_re = re.compile(r"[\t ]*")
|
||||
|
||||
|
||||
def parse_OWS(header: str, pos: int) -> int:
|
||||
"""
|
||||
Parse optional whitespace from ``header`` at the given position.
|
||||
|
||||
Return the new position.
|
||||
|
||||
The whitespace itself isn't returned because it isn't significant.
|
||||
|
||||
"""
|
||||
# There's always a match, possibly empty, whose content doesn't matter.
|
||||
match = _OWS_re.match(header, pos)
|
||||
assert match is not None
|
||||
return match.end()
|
||||
|
||||
|
||||
_token_re = re.compile(r"[-!#$%&\'*+.^_`|~0-9a-zA-Z]+")
|
||||
|
||||
|
||||
def parse_token(header: str, pos: int, header_name: str) -> tuple[str, int]:
|
||||
"""
|
||||
Parse a token from ``header`` at the given position.
|
||||
|
||||
Return the token value and the new position.
|
||||
|
||||
Raises:
|
||||
InvalidHeaderFormat: On invalid inputs.
|
||||
|
||||
"""
|
||||
match = _token_re.match(header, pos)
|
||||
if match is None:
|
||||
raise InvalidHeaderFormat(header_name, "expected token", header, pos)
|
||||
return match.group(), match.end()
|
||||
|
||||
|
||||
_quoted_string_re = re.compile(
|
||||
r'"(?:[\x09\x20-\x21\x23-\x5b\x5d-\x7e]|\\[\x09\x20-\x7e\x80-\xff])*"'
|
||||
)
|
||||
|
||||
|
||||
_unquote_re = re.compile(r"\\([\x09\x20-\x7e\x80-\xff])")
|
||||
|
||||
|
||||
def parse_quoted_string(header: str, pos: int, header_name: str) -> tuple[str, int]:
|
||||
"""
|
||||
Parse a quoted string from ``header`` at the given position.
|
||||
|
||||
Return the unquoted value and the new position.
|
||||
|
||||
Raises:
|
||||
InvalidHeaderFormat: On invalid inputs.
|
||||
|
||||
"""
|
||||
match = _quoted_string_re.match(header, pos)
|
||||
if match is None:
|
||||
raise InvalidHeaderFormat(header_name, "expected quoted string", header, pos)
|
||||
return _unquote_re.sub(r"\1", match.group()[1:-1]), match.end()
|
||||
|
||||
|
||||
_quotable_re = re.compile(r"[\x09\x20-\x7e\x80-\xff]*")
|
||||
|
||||
|
||||
_quote_re = re.compile(r"([\x22\x5c])")
|
||||
|
||||
|
||||
def build_quoted_string(value: str) -> str:
|
||||
"""
|
||||
Format ``value`` as a quoted string.
|
||||
|
||||
This is the reverse of :func:`parse_quoted_string`.
|
||||
|
||||
"""
|
||||
match = _quotable_re.fullmatch(value)
|
||||
if match is None:
|
||||
raise ValueError("invalid characters for quoted-string encoding")
|
||||
return '"' + _quote_re.sub(r"\\\1", value) + '"'
|
||||
|
||||
|
||||
def parse_list(
|
||||
parse_item: Callable[[str, int, str], tuple[T, int]],
|
||||
header: str,
|
||||
pos: int,
|
||||
header_name: str,
|
||||
) -> list[T]:
|
||||
"""
|
||||
Parse a comma-separated list from ``header`` at the given position.
|
||||
|
||||
This is appropriate for parsing values with the following grammar:
|
||||
|
||||
1#item
|
||||
|
||||
``parse_item`` parses one item.
|
||||
|
||||
``header`` is assumed not to start or end with whitespace.
|
||||
|
||||
(This function is designed for parsing an entire header value and
|
||||
:func:`~websockets.http.read_headers` strips whitespace from values.)
|
||||
|
||||
Return a list of items.
|
||||
|
||||
Raises:
|
||||
InvalidHeaderFormat: On invalid inputs.
|
||||
|
||||
"""
|
||||
# Per https://datatracker.ietf.org/doc/html/rfc7230#section-7, "a recipient
|
||||
# MUST parse and ignore a reasonable number of empty list elements";
|
||||
# hence while loops that remove extra delimiters.
|
||||
|
||||
# Remove extra delimiters before the first item.
|
||||
while peek_ahead(header, pos) == ",":
|
||||
pos = parse_OWS(header, pos + 1)
|
||||
|
||||
items = []
|
||||
while True:
|
||||
# Loop invariant: a item starts at pos in header.
|
||||
item, pos = parse_item(header, pos, header_name)
|
||||
items.append(item)
|
||||
pos = parse_OWS(header, pos)
|
||||
|
||||
# We may have reached the end of the header.
|
||||
if pos == len(header):
|
||||
break
|
||||
|
||||
# There must be a delimiter after each element except the last one.
|
||||
if peek_ahead(header, pos) == ",":
|
||||
pos = parse_OWS(header, pos + 1)
|
||||
else:
|
||||
raise InvalidHeaderFormat(header_name, "expected comma", header, pos)
|
||||
|
||||
# Remove extra delimiters before the next item.
|
||||
while peek_ahead(header, pos) == ",":
|
||||
pos = parse_OWS(header, pos + 1)
|
||||
|
||||
# We may have reached the end of the header.
|
||||
if pos == len(header):
|
||||
break
|
||||
|
||||
# Since we only advance in the header by one character with peek_ahead()
|
||||
# or with the end position of a regex match, we can't overshoot the end.
|
||||
assert pos == len(header)
|
||||
|
||||
return items
|
||||
|
||||
|
||||
def parse_connection_option(
|
||||
header: str, pos: int, header_name: str
|
||||
) -> tuple[ConnectionOption, int]:
|
||||
"""
|
||||
Parse a Connection option from ``header`` at the given position.
|
||||
|
||||
Return the protocol value and the new position.
|
||||
|
||||
Raises:
|
||||
InvalidHeaderFormat: On invalid inputs.
|
||||
|
||||
"""
|
||||
item, pos = parse_token(header, pos, header_name)
|
||||
return cast(ConnectionOption, item), pos
|
||||
|
||||
|
||||
def parse_connection(header: str) -> list[ConnectionOption]:
|
||||
"""
|
||||
Parse a ``Connection`` header.
|
||||
|
||||
Return a list of HTTP connection options.
|
||||
|
||||
Args
|
||||
header: value of the ``Connection`` header.
|
||||
|
||||
Raises:
|
||||
InvalidHeaderFormat: On invalid inputs.
|
||||
|
||||
"""
|
||||
return parse_list(parse_connection_option, header, 0, "Connection")
|
||||
|
||||
|
||||
_protocol_re = re.compile(
|
||||
r"[-!#$%&\'*+.^_`|~0-9a-zA-Z]+(?:/[-!#$%&\'*+.^_`|~0-9a-zA-Z]+)?"
|
||||
)
|
||||
|
||||
|
||||
def parse_upgrade_protocol(
|
||||
header: str, pos: int, header_name: str
|
||||
) -> tuple[UpgradeProtocol, int]:
|
||||
"""
|
||||
Parse an Upgrade protocol from ``header`` at the given position.
|
||||
|
||||
Return the protocol value and the new position.
|
||||
|
||||
Raises:
|
||||
InvalidHeaderFormat: On invalid inputs.
|
||||
|
||||
"""
|
||||
match = _protocol_re.match(header, pos)
|
||||
if match is None:
|
||||
raise InvalidHeaderFormat(header_name, "expected protocol", header, pos)
|
||||
return cast(UpgradeProtocol, match.group()), match.end()
|
||||
|
||||
|
||||
def parse_upgrade(header: str) -> list[UpgradeProtocol]:
|
||||
"""
|
||||
Parse an ``Upgrade`` header.
|
||||
|
||||
Return a list of HTTP protocols.
|
||||
|
||||
Args:
|
||||
header: Value of the ``Upgrade`` header.
|
||||
|
||||
Raises:
|
||||
InvalidHeaderFormat: On invalid inputs.
|
||||
|
||||
"""
|
||||
return parse_list(parse_upgrade_protocol, header, 0, "Upgrade")
|
||||
|
||||
|
||||
def parse_extension_item_param(
|
||||
header: str, pos: int, header_name: str
|
||||
) -> tuple[ExtensionParameter, int]:
|
||||
"""
|
||||
Parse a single extension parameter from ``header`` at the given position.
|
||||
|
||||
Return a ``(name, value)`` pair and the new position.
|
||||
|
||||
Raises:
|
||||
InvalidHeaderFormat: On invalid inputs.
|
||||
|
||||
"""
|
||||
# Extract parameter name.
|
||||
name, pos = parse_token(header, pos, header_name)
|
||||
pos = parse_OWS(header, pos)
|
||||
# Extract parameter value, if there is one.
|
||||
value: str | None = None
|
||||
if peek_ahead(header, pos) == "=":
|
||||
pos = parse_OWS(header, pos + 1)
|
||||
if peek_ahead(header, pos) == '"':
|
||||
pos_before = pos # for proper error reporting below
|
||||
value, pos = parse_quoted_string(header, pos, header_name)
|
||||
# https://datatracker.ietf.org/doc/html/rfc6455#section-9.1 says:
|
||||
# the value after quoted-string unescaping MUST conform to
|
||||
# the 'token' ABNF.
|
||||
if _token_re.fullmatch(value) is None:
|
||||
raise InvalidHeaderFormat(
|
||||
header_name, "invalid quoted header content", header, pos_before
|
||||
)
|
||||
else:
|
||||
value, pos = parse_token(header, pos, header_name)
|
||||
pos = parse_OWS(header, pos)
|
||||
|
||||
return (name, value), pos
|
||||
|
||||
|
||||
def parse_extension_item(
|
||||
header: str, pos: int, header_name: str
|
||||
) -> tuple[ExtensionHeader, int]:
|
||||
"""
|
||||
Parse an extension definition from ``header`` at the given position.
|
||||
|
||||
Return an ``(extension name, parameters)`` pair, where ``parameters`` is a
|
||||
list of ``(name, value)`` pairs, and the new position.
|
||||
|
||||
Raises:
|
||||
InvalidHeaderFormat: On invalid inputs.
|
||||
|
||||
"""
|
||||
# Extract extension name.
|
||||
name, pos = parse_token(header, pos, header_name)
|
||||
pos = parse_OWS(header, pos)
|
||||
# Extract all parameters.
|
||||
parameters = []
|
||||
while peek_ahead(header, pos) == ";":
|
||||
pos = parse_OWS(header, pos + 1)
|
||||
parameter, pos = parse_extension_item_param(header, pos, header_name)
|
||||
parameters.append(parameter)
|
||||
return (cast(ExtensionName, name), parameters), pos
|
||||
|
||||
|
||||
def parse_extension(header: str) -> list[ExtensionHeader]:
|
||||
"""
|
||||
Parse a ``Sec-WebSocket-Extensions`` header.
|
||||
|
||||
Return a list of WebSocket extensions and their parameters in this format::
|
||||
|
||||
[
|
||||
(
|
||||
'extension name',
|
||||
[
|
||||
('parameter name', 'parameter value'),
|
||||
....
|
||||
]
|
||||
),
|
||||
...
|
||||
]
|
||||
|
||||
Parameter values are :obj:`None` when no value is provided.
|
||||
|
||||
Raises:
|
||||
InvalidHeaderFormat: On invalid inputs.
|
||||
|
||||
"""
|
||||
return parse_list(parse_extension_item, header, 0, "Sec-WebSocket-Extensions")
|
||||
|
||||
|
||||
parse_extension_list = parse_extension # alias for backwards compatibility
|
||||
|
||||
|
||||
def build_extension_item(
|
||||
name: ExtensionName, parameters: list[ExtensionParameter]
|
||||
) -> str:
|
||||
"""
|
||||
Build an extension definition.
|
||||
|
||||
This is the reverse of :func:`parse_extension_item`.
|
||||
|
||||
"""
|
||||
return "; ".join(
|
||||
[cast(str, name)]
|
||||
+ [
|
||||
# Quoted strings aren't necessary because values are always tokens.
|
||||
name if value is None else f"{name}={value}"
|
||||
for name, value in parameters
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def build_extension(extensions: Sequence[ExtensionHeader]) -> str:
|
||||
"""
|
||||
Build a ``Sec-WebSocket-Extensions`` header.
|
||||
|
||||
This is the reverse of :func:`parse_extension`.
|
||||
|
||||
"""
|
||||
return ", ".join(
|
||||
build_extension_item(name, parameters) for name, parameters in extensions
|
||||
)
|
||||
|
||||
|
||||
build_extension_list = build_extension # alias for backwards compatibility
|
||||
|
||||
|
||||
def parse_subprotocol_item(
|
||||
header: str, pos: int, header_name: str
|
||||
) -> tuple[Subprotocol, int]:
|
||||
"""
|
||||
Parse a subprotocol from ``header`` at the given position.
|
||||
|
||||
Return the subprotocol value and the new position.
|
||||
|
||||
Raises:
|
||||
InvalidHeaderFormat: On invalid inputs.
|
||||
|
||||
"""
|
||||
item, pos = parse_token(header, pos, header_name)
|
||||
return cast(Subprotocol, item), pos
|
||||
|
||||
|
||||
def parse_subprotocol(header: str) -> list[Subprotocol]:
|
||||
"""
|
||||
Parse a ``Sec-WebSocket-Protocol`` header.
|
||||
|
||||
Return a list of WebSocket subprotocols.
|
||||
|
||||
Raises:
|
||||
InvalidHeaderFormat: On invalid inputs.
|
||||
|
||||
"""
|
||||
return parse_list(parse_subprotocol_item, header, 0, "Sec-WebSocket-Protocol")
|
||||
|
||||
|
||||
parse_subprotocol_list = parse_subprotocol # alias for backwards compatibility
|
||||
|
||||
|
||||
def build_subprotocol(subprotocols: Sequence[Subprotocol]) -> str:
|
||||
"""
|
||||
Build a ``Sec-WebSocket-Protocol`` header.
|
||||
|
||||
This is the reverse of :func:`parse_subprotocol`.
|
||||
|
||||
"""
|
||||
return ", ".join(subprotocols)
|
||||
|
||||
|
||||
build_subprotocol_list = build_subprotocol # alias for backwards compatibility
|
||||
|
||||
|
||||
def validate_subprotocols(subprotocols: Sequence[Subprotocol]) -> None:
|
||||
"""
|
||||
Validate that ``subprotocols`` is suitable for :func:`build_subprotocol`.
|
||||
|
||||
"""
|
||||
if not isinstance(subprotocols, Sequence):
|
||||
raise TypeError("subprotocols must be a list")
|
||||
if isinstance(subprotocols, str):
|
||||
raise TypeError("subprotocols must be a list, not a str")
|
||||
for subprotocol in subprotocols:
|
||||
if not _token_re.fullmatch(subprotocol):
|
||||
raise ValueError(f"invalid subprotocol: {subprotocol}")
|
||||
|
||||
|
||||
def build_www_authenticate_basic(realm: str) -> str:
|
||||
"""
|
||||
Build a ``WWW-Authenticate`` header for HTTP Basic Auth.
|
||||
|
||||
Args:
|
||||
realm: Identifier of the protection space.
|
||||
|
||||
"""
|
||||
# https://datatracker.ietf.org/doc/html/rfc7617#section-2
|
||||
realm = build_quoted_string(realm)
|
||||
charset = build_quoted_string("UTF-8")
|
||||
return f"Basic realm={realm}, charset={charset}"
|
||||
|
||||
|
||||
_token68_re = re.compile(r"[A-Za-z0-9-._~+/]+=*")
|
||||
|
||||
|
||||
def parse_token68(header: str, pos: int, header_name: str) -> tuple[str, int]:
|
||||
"""
|
||||
Parse a token68 from ``header`` at the given position.
|
||||
|
||||
Return the token value and the new position.
|
||||
|
||||
Raises:
|
||||
InvalidHeaderFormat: On invalid inputs.
|
||||
|
||||
"""
|
||||
match = _token68_re.match(header, pos)
|
||||
if match is None:
|
||||
raise InvalidHeaderFormat(header_name, "expected token68", header, pos)
|
||||
return match.group(), match.end()
|
||||
|
||||
|
||||
def parse_end(header: str, pos: int, header_name: str) -> None:
|
||||
"""
|
||||
Check that parsing reached the end of header.
|
||||
|
||||
"""
|
||||
if pos < len(header):
|
||||
raise InvalidHeaderFormat(header_name, "trailing data", header, pos)
|
||||
|
||||
|
||||
def parse_authorization_basic(header: str) -> tuple[str, str]:
|
||||
"""
|
||||
Parse an ``Authorization`` header for HTTP Basic Auth.
|
||||
|
||||
Return a ``(username, password)`` tuple.
|
||||
|
||||
Args:
|
||||
header: Value of the ``Authorization`` header.
|
||||
|
||||
Raises:
|
||||
InvalidHeaderFormat: On invalid inputs.
|
||||
InvalidHeaderValue: On unsupported inputs.
|
||||
|
||||
"""
|
||||
# https://datatracker.ietf.org/doc/html/rfc7235#section-2.1
|
||||
# https://datatracker.ietf.org/doc/html/rfc7617#section-2
|
||||
scheme, pos = parse_token(header, 0, "Authorization")
|
||||
if scheme.lower() != "basic":
|
||||
raise InvalidHeaderValue(
|
||||
"Authorization",
|
||||
f"unsupported scheme: {scheme}",
|
||||
)
|
||||
if peek_ahead(header, pos) != " ":
|
||||
raise InvalidHeaderFormat(
|
||||
"Authorization", "expected space after scheme", header, pos
|
||||
)
|
||||
pos += 1
|
||||
basic_credentials, pos = parse_token68(header, pos, "Authorization")
|
||||
parse_end(header, pos, "Authorization")
|
||||
|
||||
try:
|
||||
user_pass = base64.b64decode(basic_credentials.encode()).decode()
|
||||
except binascii.Error:
|
||||
raise InvalidHeaderValue(
|
||||
"Authorization",
|
||||
"expected base64-encoded credentials",
|
||||
) from None
|
||||
try:
|
||||
username, password = user_pass.split(":", 1)
|
||||
except ValueError:
|
||||
raise InvalidHeaderValue(
|
||||
"Authorization",
|
||||
"expected username:password credentials",
|
||||
) from None
|
||||
|
||||
return username, password
|
||||
|
||||
|
||||
def build_authorization_basic(username: str, password: str) -> str:
|
||||
"""
|
||||
Build an ``Authorization`` header for HTTP Basic Auth.
|
||||
|
||||
This is the reverse of :func:`parse_authorization_basic`.
|
||||
|
||||
"""
|
||||
# https://datatracker.ietf.org/doc/html/rfc7617#section-2
|
||||
assert ":" not in username
|
||||
user_pass = f"{username}:{password}"
|
||||
basic_credentials = base64.b64encode(user_pass.encode()).decode()
|
||||
return "Basic " + basic_credentials
|
@ -1,15 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import warnings
|
||||
|
||||
from .datastructures import Headers, MultipleValuesError # noqa: F401
|
||||
from .legacy.http import read_request, read_response # noqa: F401
|
||||
|
||||
|
||||
warnings.warn( # deprecated in 9.0 - 2021-09-01
|
||||
"Headers and MultipleValuesError were moved "
|
||||
"from websockets.http to websockets.datastructures"
|
||||
"and read_request and read_response were moved "
|
||||
"from websockets.http to websockets.legacy.http",
|
||||
DeprecationWarning,
|
||||
)
|
@ -1,385 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import warnings
|
||||
from typing import Callable, Generator
|
||||
|
||||
from .datastructures import Headers
|
||||
from .exceptions import SecurityError
|
||||
from .version import version as websockets_version
|
||||
|
||||
|
||||
__all__ = ["SERVER", "USER_AGENT", "Request", "Response"]
|
||||
|
||||
|
||||
PYTHON_VERSION = "{}.{}".format(*sys.version_info)
|
||||
|
||||
# User-Agent header for HTTP requests.
|
||||
USER_AGENT = os.environ.get(
|
||||
"WEBSOCKETS_USER_AGENT",
|
||||
f"Python/{PYTHON_VERSION} websockets/{websockets_version}",
|
||||
)
|
||||
|
||||
# Server header for HTTP responses.
|
||||
SERVER = os.environ.get(
|
||||
"WEBSOCKETS_SERVER",
|
||||
f"Python/{PYTHON_VERSION} websockets/{websockets_version}",
|
||||
)
|
||||
|
||||
# Maximum total size of headers is around 128 * 8 KiB = 1 MiB.
|
||||
MAX_NUM_HEADERS = int(os.environ.get("WEBSOCKETS_MAX_NUM_HEADERS", "128"))
|
||||
|
||||
# Limit request line and header lines. 8KiB is the most common default
|
||||
# configuration of popular HTTP servers.
|
||||
MAX_LINE_LENGTH = int(os.environ.get("WEBSOCKETS_MAX_LINE_LENGTH", "8192"))
|
||||
|
||||
# Support for HTTP response bodies is intended to read an error message
|
||||
# returned by a server. It isn't designed to perform large file transfers.
|
||||
MAX_BODY_SIZE = int(os.environ.get("WEBSOCKETS_MAX_BODY_SIZE", "1_048_576")) # 1 MiB
|
||||
|
||||
|
||||
def d(value: bytes) -> str:
|
||||
"""
|
||||
Decode a bytestring for interpolating into an error message.
|
||||
|
||||
"""
|
||||
return value.decode(errors="backslashreplace")
|
||||
|
||||
|
||||
# See https://datatracker.ietf.org/doc/html/rfc7230#appendix-B.
|
||||
|
||||
# Regex for validating header names.
|
||||
|
||||
_token_re = re.compile(rb"[-!#$%&\'*+.^_`|~0-9a-zA-Z]+")
|
||||
|
||||
# Regex for validating header values.
|
||||
|
||||
# We don't attempt to support obsolete line folding.
|
||||
|
||||
# Include HTAB (\x09), SP (\x20), VCHAR (\x21-\x7e), obs-text (\x80-\xff).
|
||||
|
||||
# The ABNF is complicated because it attempts to express that optional
|
||||
# whitespace is ignored. We strip whitespace and don't revalidate that.
|
||||
|
||||
# See also https://www.rfc-editor.org/errata_search.php?rfc=7230&eid=4189
|
||||
|
||||
_value_re = re.compile(rb"[\x09\x20-\x7e\x80-\xff]*")
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Request:
|
||||
"""
|
||||
WebSocket handshake request.
|
||||
|
||||
Attributes:
|
||||
path: Request path, including optional query.
|
||||
headers: Request headers.
|
||||
"""
|
||||
|
||||
path: str
|
||||
headers: Headers
|
||||
# body isn't useful is the context of this library.
|
||||
|
||||
_exception: Exception | None = None
|
||||
|
||||
@property
|
||||
def exception(self) -> Exception | None: # pragma: no cover
|
||||
warnings.warn( # deprecated in 10.3 - 2022-04-17
|
||||
"Request.exception is deprecated; "
|
||||
"use ServerProtocol.handshake_exc instead",
|
||||
DeprecationWarning,
|
||||
)
|
||||
return self._exception
|
||||
|
||||
@classmethod
|
||||
def parse(
|
||||
cls,
|
||||
read_line: Callable[[int], Generator[None, None, bytes]],
|
||||
) -> Generator[None, None, Request]:
|
||||
"""
|
||||
Parse a WebSocket handshake request.
|
||||
|
||||
This is a generator-based coroutine.
|
||||
|
||||
The request path isn't URL-decoded or validated in any way.
|
||||
|
||||
The request path and headers are expected to contain only ASCII
|
||||
characters. Other characters are represented with surrogate escapes.
|
||||
|
||||
:meth:`parse` doesn't attempt to read the request body because
|
||||
WebSocket handshake requests don't have one. If the request contains a
|
||||
body, it may be read from the data stream after :meth:`parse` returns.
|
||||
|
||||
Args:
|
||||
read_line: Generator-based coroutine that reads a LF-terminated
|
||||
line or raises an exception if there isn't enough data
|
||||
|
||||
Raises:
|
||||
EOFError: If the connection is closed without a full HTTP request.
|
||||
SecurityError: If the request exceeds a security limit.
|
||||
ValueError: If the request isn't well formatted.
|
||||
|
||||
"""
|
||||
# https://datatracker.ietf.org/doc/html/rfc7230#section-3.1.1
|
||||
|
||||
# Parsing is simple because fixed values are expected for method and
|
||||
# version and because path isn't checked. Since WebSocket software tends
|
||||
# to implement HTTP/1.1 strictly, there's little need for lenient parsing.
|
||||
|
||||
try:
|
||||
request_line = yield from parse_line(read_line)
|
||||
except EOFError as exc:
|
||||
raise EOFError("connection closed while reading HTTP request line") from exc
|
||||
|
||||
try:
|
||||
method, raw_path, protocol = request_line.split(b" ", 2)
|
||||
except ValueError: # not enough values to unpack (expected 3, got 1-2)
|
||||
raise ValueError(f"invalid HTTP request line: {d(request_line)}") from None
|
||||
if protocol != b"HTTP/1.1":
|
||||
raise ValueError(
|
||||
f"unsupported protocol; expected HTTP/1.1: {d(request_line)}"
|
||||
)
|
||||
if method != b"GET":
|
||||
raise ValueError(f"unsupported HTTP method; expected GET; got {d(method)}")
|
||||
path = raw_path.decode("ascii", "surrogateescape")
|
||||
|
||||
headers = yield from parse_headers(read_line)
|
||||
|
||||
# https://datatracker.ietf.org/doc/html/rfc7230#section-3.3.3
|
||||
|
||||
if "Transfer-Encoding" in headers:
|
||||
raise NotImplementedError("transfer codings aren't supported")
|
||||
|
||||
if "Content-Length" in headers:
|
||||
raise ValueError("unsupported request body")
|
||||
|
||||
return cls(path, headers)
|
||||
|
||||
def serialize(self) -> bytes:
|
||||
"""
|
||||
Serialize a WebSocket handshake request.
|
||||
|
||||
"""
|
||||
# Since the request line and headers only contain ASCII characters,
|
||||
# we can keep this simple.
|
||||
request = f"GET {self.path} HTTP/1.1\r\n".encode()
|
||||
request += self.headers.serialize()
|
||||
return request
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Response:
|
||||
"""
|
||||
WebSocket handshake response.
|
||||
|
||||
Attributes:
|
||||
status_code: Response code.
|
||||
reason_phrase: Response reason.
|
||||
headers: Response headers.
|
||||
body: Response body, if any.
|
||||
|
||||
"""
|
||||
|
||||
status_code: int
|
||||
reason_phrase: str
|
||||
headers: Headers
|
||||
body: bytes | None = None
|
||||
|
||||
_exception: Exception | None = None
|
||||
|
||||
@property
|
||||
def exception(self) -> Exception | None: # pragma: no cover
|
||||
warnings.warn( # deprecated in 10.3 - 2022-04-17
|
||||
"Response.exception is deprecated; "
|
||||
"use ClientProtocol.handshake_exc instead",
|
||||
DeprecationWarning,
|
||||
)
|
||||
return self._exception
|
||||
|
||||
@classmethod
|
||||
def parse(
|
||||
cls,
|
||||
read_line: Callable[[int], Generator[None, None, bytes]],
|
||||
read_exact: Callable[[int], Generator[None, None, bytes]],
|
||||
read_to_eof: Callable[[int], Generator[None, None, bytes]],
|
||||
) -> Generator[None, None, Response]:
|
||||
"""
|
||||
Parse a WebSocket handshake response.
|
||||
|
||||
This is a generator-based coroutine.
|
||||
|
||||
The reason phrase and headers are expected to contain only ASCII
|
||||
characters. Other characters are represented with surrogate escapes.
|
||||
|
||||
Args:
|
||||
read_line: Generator-based coroutine that reads a LF-terminated
|
||||
line or raises an exception if there isn't enough data.
|
||||
read_exact: Generator-based coroutine that reads the requested
|
||||
bytes or raises an exception if there isn't enough data.
|
||||
read_to_eof: Generator-based coroutine that reads until the end
|
||||
of the stream.
|
||||
|
||||
Raises:
|
||||
EOFError: If the connection is closed without a full HTTP response.
|
||||
SecurityError: If the response exceeds a security limit.
|
||||
LookupError: If the response isn't well formatted.
|
||||
ValueError: If the response isn't well formatted.
|
||||
|
||||
"""
|
||||
# https://datatracker.ietf.org/doc/html/rfc7230#section-3.1.2
|
||||
|
||||
try:
|
||||
status_line = yield from parse_line(read_line)
|
||||
except EOFError as exc:
|
||||
raise EOFError("connection closed while reading HTTP status line") from exc
|
||||
|
||||
try:
|
||||
protocol, raw_status_code, raw_reason = status_line.split(b" ", 2)
|
||||
except ValueError: # not enough values to unpack (expected 3, got 1-2)
|
||||
raise ValueError(f"invalid HTTP status line: {d(status_line)}") from None
|
||||
if protocol != b"HTTP/1.1":
|
||||
raise ValueError(
|
||||
f"unsupported protocol; expected HTTP/1.1: {d(status_line)}"
|
||||
)
|
||||
try:
|
||||
status_code = int(raw_status_code)
|
||||
except ValueError: # invalid literal for int() with base 10
|
||||
raise ValueError(
|
||||
f"invalid status code; expected integer; got {d(raw_status_code)}"
|
||||
) from None
|
||||
if not 100 <= status_code < 600:
|
||||
raise ValueError(
|
||||
f"invalid status code; expected 100–599; got {d(raw_status_code)}"
|
||||
)
|
||||
if not _value_re.fullmatch(raw_reason):
|
||||
raise ValueError(f"invalid HTTP reason phrase: {d(raw_reason)}")
|
||||
reason = raw_reason.decode("ascii", "surrogateescape")
|
||||
|
||||
headers = yield from parse_headers(read_line)
|
||||
|
||||
# https://datatracker.ietf.org/doc/html/rfc7230#section-3.3.3
|
||||
|
||||
if "Transfer-Encoding" in headers:
|
||||
raise NotImplementedError("transfer codings aren't supported")
|
||||
|
||||
# Since websockets only does GET requests (no HEAD, no CONNECT), all
|
||||
# responses except 1xx, 204, and 304 include a message body.
|
||||
if 100 <= status_code < 200 or status_code == 204 or status_code == 304:
|
||||
body = None
|
||||
else:
|
||||
content_length: int | None
|
||||
try:
|
||||
# MultipleValuesError is sufficiently unlikely that we don't
|
||||
# attempt to handle it. Instead we document that its parent
|
||||
# class, LookupError, may be raised.
|
||||
raw_content_length = headers["Content-Length"]
|
||||
except KeyError:
|
||||
content_length = None
|
||||
else:
|
||||
content_length = int(raw_content_length)
|
||||
|
||||
if content_length is None:
|
||||
try:
|
||||
body = yield from read_to_eof(MAX_BODY_SIZE)
|
||||
except RuntimeError:
|
||||
raise SecurityError(f"body too large: over {MAX_BODY_SIZE} bytes")
|
||||
elif content_length > MAX_BODY_SIZE:
|
||||
raise SecurityError(f"body too large: {content_length} bytes")
|
||||
else:
|
||||
body = yield from read_exact(content_length)
|
||||
|
||||
return cls(status_code, reason, headers, body)
|
||||
|
||||
def serialize(self) -> bytes:
|
||||
"""
|
||||
Serialize a WebSocket handshake response.
|
||||
|
||||
"""
|
||||
# Since the status line and headers only contain ASCII characters,
|
||||
# we can keep this simple.
|
||||
response = f"HTTP/1.1 {self.status_code} {self.reason_phrase}\r\n".encode()
|
||||
response += self.headers.serialize()
|
||||
if self.body is not None:
|
||||
response += self.body
|
||||
return response
|
||||
|
||||
|
||||
def parse_headers(
|
||||
read_line: Callable[[int], Generator[None, None, bytes]],
|
||||
) -> Generator[None, None, Headers]:
|
||||
"""
|
||||
Parse HTTP headers.
|
||||
|
||||
Non-ASCII characters are represented with surrogate escapes.
|
||||
|
||||
Args:
|
||||
read_line: Generator-based coroutine that reads a LF-terminated line
|
||||
or raises an exception if there isn't enough data.
|
||||
|
||||
Raises:
|
||||
EOFError: If the connection is closed without complete headers.
|
||||
SecurityError: If the request exceeds a security limit.
|
||||
ValueError: If the request isn't well formatted.
|
||||
|
||||
"""
|
||||
# https://datatracker.ietf.org/doc/html/rfc7230#section-3.2
|
||||
|
||||
# We don't attempt to support obsolete line folding.
|
||||
|
||||
headers = Headers()
|
||||
for _ in range(MAX_NUM_HEADERS + 1):
|
||||
try:
|
||||
line = yield from parse_line(read_line)
|
||||
except EOFError as exc:
|
||||
raise EOFError("connection closed while reading HTTP headers") from exc
|
||||
if line == b"":
|
||||
break
|
||||
|
||||
try:
|
||||
raw_name, raw_value = line.split(b":", 1)
|
||||
except ValueError: # not enough values to unpack (expected 2, got 1)
|
||||
raise ValueError(f"invalid HTTP header line: {d(line)}") from None
|
||||
if not _token_re.fullmatch(raw_name):
|
||||
raise ValueError(f"invalid HTTP header name: {d(raw_name)}")
|
||||
raw_value = raw_value.strip(b" \t")
|
||||
if not _value_re.fullmatch(raw_value):
|
||||
raise ValueError(f"invalid HTTP header value: {d(raw_value)}")
|
||||
|
||||
name = raw_name.decode("ascii") # guaranteed to be ASCII at this point
|
||||
value = raw_value.decode("ascii", "surrogateescape")
|
||||
headers[name] = value
|
||||
|
||||
else:
|
||||
raise SecurityError("too many HTTP headers")
|
||||
|
||||
return headers
|
||||
|
||||
|
||||
def parse_line(
|
||||
read_line: Callable[[int], Generator[None, None, bytes]],
|
||||
) -> Generator[None, None, bytes]:
|
||||
"""
|
||||
Parse a single line.
|
||||
|
||||
CRLF is stripped from the return value.
|
||||
|
||||
Args:
|
||||
read_line: Generator-based coroutine that reads a LF-terminated line
|
||||
or raises an exception if there isn't enough data.
|
||||
|
||||
Raises:
|
||||
EOFError: If the connection is closed without a CRLF.
|
||||
SecurityError: If the response exceeds a security limit.
|
||||
|
||||
"""
|
||||
try:
|
||||
line = yield from read_line(MAX_LINE_LENGTH)
|
||||
except RuntimeError:
|
||||
raise SecurityError("line too long")
|
||||
# Not mandatory but safe - https://datatracker.ietf.org/doc/html/rfc7230#section-3.5
|
||||
if not line.endswith(b"\r\n"):
|
||||
raise EOFError("line without CRLF")
|
||||
return line[:-2]
|
@ -1,99 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import warnings
|
||||
from typing import Any, Iterable
|
||||
|
||||
|
||||
__all__ = ["lazy_import"]
|
||||
|
||||
|
||||
def import_name(name: str, source: str, namespace: dict[str, Any]) -> Any:
|
||||
"""
|
||||
Import ``name`` from ``source`` in ``namespace``.
|
||||
|
||||
There are two use cases:
|
||||
|
||||
- ``name`` is an object defined in ``source``;
|
||||
- ``name`` is a submodule of ``source``.
|
||||
|
||||
Neither :func:`__import__` nor :func:`~importlib.import_module` does
|
||||
exactly this. :func:`__import__` is closer to the intended behavior.
|
||||
|
||||
"""
|
||||
level = 0
|
||||
while source[level] == ".":
|
||||
level += 1
|
||||
assert level < len(source), "importing from parent isn't supported"
|
||||
module = __import__(source[level:], namespace, None, [name], level)
|
||||
return getattr(module, name)
|
||||
|
||||
|
||||
def lazy_import(
|
||||
namespace: dict[str, Any],
|
||||
aliases: dict[str, str] | None = None,
|
||||
deprecated_aliases: dict[str, str] | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Provide lazy, module-level imports.
|
||||
|
||||
Typical use::
|
||||
|
||||
__getattr__, __dir__ = lazy_import(
|
||||
globals(),
|
||||
aliases={
|
||||
"<name>": "<source module>",
|
||||
...
|
||||
},
|
||||
deprecated_aliases={
|
||||
...,
|
||||
}
|
||||
)
|
||||
|
||||
This function defines ``__getattr__`` and ``__dir__`` per :pep:`562`.
|
||||
|
||||
"""
|
||||
if aliases is None:
|
||||
aliases = {}
|
||||
if deprecated_aliases is None:
|
||||
deprecated_aliases = {}
|
||||
|
||||
namespace_set = set(namespace)
|
||||
aliases_set = set(aliases)
|
||||
deprecated_aliases_set = set(deprecated_aliases)
|
||||
|
||||
assert not namespace_set & aliases_set, "namespace conflict"
|
||||
assert not namespace_set & deprecated_aliases_set, "namespace conflict"
|
||||
assert not aliases_set & deprecated_aliases_set, "namespace conflict"
|
||||
|
||||
package = namespace["__name__"]
|
||||
|
||||
def __getattr__(name: str) -> Any:
|
||||
assert aliases is not None # mypy cannot figure this out
|
||||
try:
|
||||
source = aliases[name]
|
||||
except KeyError:
|
||||
pass
|
||||
else:
|
||||
return import_name(name, source, namespace)
|
||||
|
||||
assert deprecated_aliases is not None # mypy cannot figure this out
|
||||
try:
|
||||
source = deprecated_aliases[name]
|
||||
except KeyError:
|
||||
pass
|
||||
else:
|
||||
warnings.warn(
|
||||
f"{package}.{name} is deprecated",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
return import_name(name, source, namespace)
|
||||
|
||||
raise AttributeError(f"module {package!r} has no attribute {name!r}")
|
||||
|
||||
namespace["__getattr__"] = __getattr__
|
||||
|
||||
def __dir__() -> Iterable[str]:
|
||||
return sorted(namespace_set | aliases_set | deprecated_aliases_set)
|
||||
|
||||
namespace["__dir__"] = __dir__
|
@ -1,190 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import hmac
|
||||
import http
|
||||
from typing import Any, Awaitable, Callable, Iterable, Tuple, cast
|
||||
|
||||
from ..datastructures import Headers
|
||||
from ..exceptions import InvalidHeader
|
||||
from ..headers import build_www_authenticate_basic, parse_authorization_basic
|
||||
from .server import HTTPResponse, WebSocketServerProtocol
|
||||
|
||||
|
||||
__all__ = ["BasicAuthWebSocketServerProtocol", "basic_auth_protocol_factory"]
|
||||
|
||||
# Change to tuple[str, str] when dropping Python < 3.9.
|
||||
Credentials = Tuple[str, str]
|
||||
|
||||
|
||||
def is_credentials(value: Any) -> bool:
|
||||
try:
|
||||
username, password = value
|
||||
except (TypeError, ValueError):
|
||||
return False
|
||||
else:
|
||||
return isinstance(username, str) and isinstance(password, str)
|
||||
|
||||
|
||||
class BasicAuthWebSocketServerProtocol(WebSocketServerProtocol):
|
||||
"""
|
||||
WebSocket server protocol that enforces HTTP Basic Auth.
|
||||
|
||||
"""
|
||||
|
||||
realm: str = ""
|
||||
"""
|
||||
Scope of protection.
|
||||
|
||||
If provided, it should contain only ASCII characters because the
|
||||
encoding of non-ASCII characters is undefined.
|
||||
"""
|
||||
|
||||
username: str | None = None
|
||||
"""Username of the authenticated user."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*args: Any,
|
||||
realm: str | None = None,
|
||||
check_credentials: Callable[[str, str], Awaitable[bool]] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
if realm is not None:
|
||||
self.realm = realm # shadow class attribute
|
||||
self._check_credentials = check_credentials
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
async def check_credentials(self, username: str, password: str) -> bool:
|
||||
"""
|
||||
Check whether credentials are authorized.
|
||||
|
||||
This coroutine may be overridden in a subclass, for example to
|
||||
authenticate against a database or an external service.
|
||||
|
||||
Args:
|
||||
username: HTTP Basic Auth username.
|
||||
password: HTTP Basic Auth password.
|
||||
|
||||
Returns:
|
||||
:obj:`True` if the handshake should continue;
|
||||
:obj:`False` if it should fail with an HTTP 401 error.
|
||||
|
||||
"""
|
||||
if self._check_credentials is not None:
|
||||
return await self._check_credentials(username, password)
|
||||
|
||||
return False
|
||||
|
||||
async def process_request(
|
||||
self,
|
||||
path: str,
|
||||
request_headers: Headers,
|
||||
) -> HTTPResponse | None:
|
||||
"""
|
||||
Check HTTP Basic Auth and return an HTTP 401 response if needed.
|
||||
|
||||
"""
|
||||
try:
|
||||
authorization = request_headers["Authorization"]
|
||||
except KeyError:
|
||||
return (
|
||||
http.HTTPStatus.UNAUTHORIZED,
|
||||
[("WWW-Authenticate", build_www_authenticate_basic(self.realm))],
|
||||
b"Missing credentials\n",
|
||||
)
|
||||
|
||||
try:
|
||||
username, password = parse_authorization_basic(authorization)
|
||||
except InvalidHeader:
|
||||
return (
|
||||
http.HTTPStatus.UNAUTHORIZED,
|
||||
[("WWW-Authenticate", build_www_authenticate_basic(self.realm))],
|
||||
b"Unsupported credentials\n",
|
||||
)
|
||||
|
||||
if not await self.check_credentials(username, password):
|
||||
return (
|
||||
http.HTTPStatus.UNAUTHORIZED,
|
||||
[("WWW-Authenticate", build_www_authenticate_basic(self.realm))],
|
||||
b"Invalid credentials\n",
|
||||
)
|
||||
|
||||
self.username = username
|
||||
|
||||
return await super().process_request(path, request_headers)
|
||||
|
||||
|
||||
def basic_auth_protocol_factory(
|
||||
realm: str | None = None,
|
||||
credentials: Credentials | Iterable[Credentials] | None = None,
|
||||
check_credentials: Callable[[str, str], Awaitable[bool]] | None = None,
|
||||
create_protocol: Callable[..., BasicAuthWebSocketServerProtocol] | None = None,
|
||||
) -> Callable[..., BasicAuthWebSocketServerProtocol]:
|
||||
"""
|
||||
Protocol factory that enforces HTTP Basic Auth.
|
||||
|
||||
:func:`basic_auth_protocol_factory` is designed to integrate with
|
||||
:func:`~websockets.legacy.server.serve` like this::
|
||||
|
||||
serve(
|
||||
...,
|
||||
create_protocol=basic_auth_protocol_factory(
|
||||
realm="my dev server",
|
||||
credentials=("hello", "iloveyou"),
|
||||
)
|
||||
)
|
||||
|
||||
Args:
|
||||
realm: Scope of protection. It should contain only ASCII characters
|
||||
because the encoding of non-ASCII characters is undefined.
|
||||
Refer to section 2.2 of :rfc:`7235` for details.
|
||||
credentials: Hard coded authorized credentials. It can be a
|
||||
``(username, password)`` pair or a list of such pairs.
|
||||
check_credentials: Coroutine that verifies credentials.
|
||||
It receives ``username`` and ``password`` arguments
|
||||
and returns a :class:`bool`. One of ``credentials`` or
|
||||
``check_credentials`` must be provided but not both.
|
||||
create_protocol: Factory that creates the protocol. By default, this
|
||||
is :class:`BasicAuthWebSocketServerProtocol`. It can be replaced
|
||||
by a subclass.
|
||||
Raises:
|
||||
TypeError: If the ``credentials`` or ``check_credentials`` argument is
|
||||
wrong.
|
||||
|
||||
"""
|
||||
if (credentials is None) == (check_credentials is None):
|
||||
raise TypeError("provide either credentials or check_credentials")
|
||||
|
||||
if credentials is not None:
|
||||
if is_credentials(credentials):
|
||||
credentials_list = [cast(Credentials, credentials)]
|
||||
elif isinstance(credentials, Iterable):
|
||||
credentials_list = list(cast(Iterable[Credentials], credentials))
|
||||
if not all(is_credentials(item) for item in credentials_list):
|
||||
raise TypeError(f"invalid credentials argument: {credentials}")
|
||||
else:
|
||||
raise TypeError(f"invalid credentials argument: {credentials}")
|
||||
|
||||
credentials_dict = dict(credentials_list)
|
||||
|
||||
async def check_credentials(username: str, password: str) -> bool:
|
||||
try:
|
||||
expected_password = credentials_dict[username]
|
||||
except KeyError:
|
||||
return False
|
||||
return hmac.compare_digest(expected_password, password)
|
||||
|
||||
if create_protocol is None:
|
||||
create_protocol = BasicAuthWebSocketServerProtocol
|
||||
|
||||
# Help mypy and avoid this error: "type[BasicAuthWebSocketServerProtocol] |
|
||||
# Callable[..., BasicAuthWebSocketServerProtocol]" not callable [misc]
|
||||
create_protocol = cast(
|
||||
Callable[..., BasicAuthWebSocketServerProtocol], create_protocol
|
||||
)
|
||||
return functools.partial(
|
||||
create_protocol,
|
||||
realm=realm,
|
||||
check_credentials=check_credentials,
|
||||
)
|
@ -1,707 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import functools
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import urllib.parse
|
||||
import warnings
|
||||
from types import TracebackType
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Callable,
|
||||
Generator,
|
||||
Sequence,
|
||||
cast,
|
||||
)
|
||||
|
||||
from ..asyncio.compatibility import asyncio_timeout
|
||||
from ..datastructures import Headers, HeadersLike
|
||||
from ..exceptions import (
|
||||
InvalidHeader,
|
||||
InvalidHeaderValue,
|
||||
NegotiationError,
|
||||
SecurityError,
|
||||
)
|
||||
from ..extensions import ClientExtensionFactory, Extension
|
||||
from ..extensions.permessage_deflate import enable_client_permessage_deflate
|
||||
from ..headers import (
|
||||
build_authorization_basic,
|
||||
build_extension,
|
||||
build_host,
|
||||
build_subprotocol,
|
||||
parse_extension,
|
||||
parse_subprotocol,
|
||||
validate_subprotocols,
|
||||
)
|
||||
from ..http11 import USER_AGENT
|
||||
from ..typing import ExtensionHeader, LoggerLike, Origin, Subprotocol
|
||||
from ..uri import WebSocketURI, parse_uri
|
||||
from .exceptions import InvalidMessage, InvalidStatusCode, RedirectHandshake
|
||||
from .handshake import build_request, check_response
|
||||
from .http import read_response
|
||||
from .protocol import WebSocketCommonProtocol
|
||||
|
||||
|
||||
__all__ = ["connect", "unix_connect", "WebSocketClientProtocol"]
|
||||
|
||||
|
||||
class WebSocketClientProtocol(WebSocketCommonProtocol):
|
||||
"""
|
||||
WebSocket client connection.
|
||||
|
||||
:class:`WebSocketClientProtocol` provides :meth:`recv` and :meth:`send`
|
||||
coroutines for receiving and sending messages.
|
||||
|
||||
It supports asynchronous iteration to receive messages::
|
||||
|
||||
async for message in websocket:
|
||||
await process(message)
|
||||
|
||||
The iterator exits normally when the connection is closed with close code
|
||||
1000 (OK) or 1001 (going away) or without a close code. It raises
|
||||
a :exc:`~websockets.exceptions.ConnectionClosedError` when the connection
|
||||
is closed with any other code.
|
||||
|
||||
See :func:`connect` for the documentation of ``logger``, ``origin``,
|
||||
``extensions``, ``subprotocols``, ``extra_headers``, and
|
||||
``user_agent_header``.
|
||||
|
||||
See :class:`~websockets.legacy.protocol.WebSocketCommonProtocol` for the
|
||||
documentation of ``ping_interval``, ``ping_timeout``, ``close_timeout``,
|
||||
``max_size``, ``max_queue``, ``read_limit``, and ``write_limit``.
|
||||
|
||||
"""
|
||||
|
||||
is_client = True
|
||||
side = "client"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
logger: LoggerLike | None = None,
|
||||
origin: Origin | None = None,
|
||||
extensions: Sequence[ClientExtensionFactory] | None = None,
|
||||
subprotocols: Sequence[Subprotocol] | None = None,
|
||||
extra_headers: HeadersLike | None = None,
|
||||
user_agent_header: str | None = USER_AGENT,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
if logger is None:
|
||||
logger = logging.getLogger("websockets.client")
|
||||
super().__init__(logger=logger, **kwargs)
|
||||
self.origin = origin
|
||||
self.available_extensions = extensions
|
||||
self.available_subprotocols = subprotocols
|
||||
self.extra_headers = extra_headers
|
||||
self.user_agent_header = user_agent_header
|
||||
|
||||
def write_http_request(self, path: str, headers: Headers) -> None:
|
||||
"""
|
||||
Write request line and headers to the HTTP request.
|
||||
|
||||
"""
|
||||
self.path = path
|
||||
self.request_headers = headers
|
||||
|
||||
if self.debug:
|
||||
self.logger.debug("> GET %s HTTP/1.1", path)
|
||||
for key, value in headers.raw_items():
|
||||
self.logger.debug("> %s: %s", key, value)
|
||||
|
||||
# Since the path and headers only contain ASCII characters,
|
||||
# we can keep this simple.
|
||||
request = f"GET {path} HTTP/1.1\r\n"
|
||||
request += str(headers)
|
||||
|
||||
self.transport.write(request.encode())
|
||||
|
||||
async def read_http_response(self) -> tuple[int, Headers]:
|
||||
"""
|
||||
Read status line and headers from the HTTP response.
|
||||
|
||||
If the response contains a body, it may be read from ``self.reader``
|
||||
after this coroutine returns.
|
||||
|
||||
Raises:
|
||||
InvalidMessage: If the HTTP message is malformed or isn't an
|
||||
HTTP/1.1 GET response.
|
||||
|
||||
"""
|
||||
try:
|
||||
status_code, reason, headers = await read_response(self.reader)
|
||||
except Exception as exc:
|
||||
raise InvalidMessage("did not receive a valid HTTP response") from exc
|
||||
|
||||
if self.debug:
|
||||
self.logger.debug("< HTTP/1.1 %d %s", status_code, reason)
|
||||
for key, value in headers.raw_items():
|
||||
self.logger.debug("< %s: %s", key, value)
|
||||
|
||||
self.response_headers = headers
|
||||
|
||||
return status_code, self.response_headers
|
||||
|
||||
@staticmethod
|
||||
def process_extensions(
|
||||
headers: Headers,
|
||||
available_extensions: Sequence[ClientExtensionFactory] | None,
|
||||
) -> list[Extension]:
|
||||
"""
|
||||
Handle the Sec-WebSocket-Extensions HTTP response header.
|
||||
|
||||
Check that each extension is supported, as well as its parameters.
|
||||
|
||||
Return the list of accepted extensions.
|
||||
|
||||
Raise :exc:`~websockets.exceptions.InvalidHandshake` to abort the
|
||||
connection.
|
||||
|
||||
:rfc:`6455` leaves the rules up to the specification of each
|
||||
:extension.
|
||||
|
||||
To provide this level of flexibility, for each extension accepted by
|
||||
the server, we check for a match with each extension available in the
|
||||
client configuration. If no match is found, an exception is raised.
|
||||
|
||||
If several variants of the same extension are accepted by the server,
|
||||
it may be configured several times, which won't make sense in general.
|
||||
Extensions must implement their own requirements. For this purpose,
|
||||
the list of previously accepted extensions is provided.
|
||||
|
||||
Other requirements, for example related to mandatory extensions or the
|
||||
order of extensions, may be implemented by overriding this method.
|
||||
|
||||
"""
|
||||
accepted_extensions: list[Extension] = []
|
||||
|
||||
header_values = headers.get_all("Sec-WebSocket-Extensions")
|
||||
|
||||
if header_values:
|
||||
if available_extensions is None:
|
||||
raise NegotiationError("no extensions supported")
|
||||
|
||||
parsed_header_values: list[ExtensionHeader] = sum(
|
||||
[parse_extension(header_value) for header_value in header_values], []
|
||||
)
|
||||
|
||||
for name, response_params in parsed_header_values:
|
||||
for extension_factory in available_extensions:
|
||||
# Skip non-matching extensions based on their name.
|
||||
if extension_factory.name != name:
|
||||
continue
|
||||
|
||||
# Skip non-matching extensions based on their params.
|
||||
try:
|
||||
extension = extension_factory.process_response_params(
|
||||
response_params, accepted_extensions
|
||||
)
|
||||
except NegotiationError:
|
||||
continue
|
||||
|
||||
# Add matching extension to the final list.
|
||||
accepted_extensions.append(extension)
|
||||
|
||||
# Break out of the loop once we have a match.
|
||||
break
|
||||
|
||||
# If we didn't break from the loop, no extension in our list
|
||||
# matched what the server sent. Fail the connection.
|
||||
else:
|
||||
raise NegotiationError(
|
||||
f"Unsupported extension: "
|
||||
f"name = {name}, params = {response_params}"
|
||||
)
|
||||
|
||||
return accepted_extensions
|
||||
|
||||
@staticmethod
|
||||
def process_subprotocol(
|
||||
headers: Headers, available_subprotocols: Sequence[Subprotocol] | None
|
||||
) -> Subprotocol | None:
|
||||
"""
|
||||
Handle the Sec-WebSocket-Protocol HTTP response header.
|
||||
|
||||
Check that it contains exactly one supported subprotocol.
|
||||
|
||||
Return the selected subprotocol.
|
||||
|
||||
"""
|
||||
subprotocol: Subprotocol | None = None
|
||||
|
||||
header_values = headers.get_all("Sec-WebSocket-Protocol")
|
||||
|
||||
if header_values:
|
||||
if available_subprotocols is None:
|
||||
raise NegotiationError("no subprotocols supported")
|
||||
|
||||
parsed_header_values: Sequence[Subprotocol] = sum(
|
||||
[parse_subprotocol(header_value) for header_value in header_values], []
|
||||
)
|
||||
|
||||
if len(parsed_header_values) > 1:
|
||||
raise InvalidHeaderValue(
|
||||
"Sec-WebSocket-Protocol",
|
||||
f"multiple values: {', '.join(parsed_header_values)}",
|
||||
)
|
||||
|
||||
subprotocol = parsed_header_values[0]
|
||||
|
||||
if subprotocol not in available_subprotocols:
|
||||
raise NegotiationError(f"unsupported subprotocol: {subprotocol}")
|
||||
|
||||
return subprotocol
|
||||
|
||||
async def handshake(
|
||||
self,
|
||||
wsuri: WebSocketURI,
|
||||
origin: Origin | None = None,
|
||||
available_extensions: Sequence[ClientExtensionFactory] | None = None,
|
||||
available_subprotocols: Sequence[Subprotocol] | None = None,
|
||||
extra_headers: HeadersLike | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Perform the client side of the opening handshake.
|
||||
|
||||
Args:
|
||||
wsuri: URI of the WebSocket server.
|
||||
origin: Value of the ``Origin`` header.
|
||||
extensions: List of supported extensions, in order in which they
|
||||
should be negotiated and run.
|
||||
subprotocols: List of supported subprotocols, in order of decreasing
|
||||
preference.
|
||||
extra_headers: Arbitrary HTTP headers to add to the handshake request.
|
||||
|
||||
Raises:
|
||||
InvalidHandshake: If the handshake fails.
|
||||
|
||||
"""
|
||||
request_headers = Headers()
|
||||
|
||||
request_headers["Host"] = build_host(wsuri.host, wsuri.port, wsuri.secure)
|
||||
|
||||
if wsuri.user_info:
|
||||
request_headers["Authorization"] = build_authorization_basic(
|
||||
*wsuri.user_info
|
||||
)
|
||||
|
||||
if origin is not None:
|
||||
request_headers["Origin"] = origin
|
||||
|
||||
key = build_request(request_headers)
|
||||
|
||||
if available_extensions is not None:
|
||||
extensions_header = build_extension(
|
||||
[
|
||||
(extension_factory.name, extension_factory.get_request_params())
|
||||
for extension_factory in available_extensions
|
||||
]
|
||||
)
|
||||
request_headers["Sec-WebSocket-Extensions"] = extensions_header
|
||||
|
||||
if available_subprotocols is not None:
|
||||
protocol_header = build_subprotocol(available_subprotocols)
|
||||
request_headers["Sec-WebSocket-Protocol"] = protocol_header
|
||||
|
||||
if self.extra_headers is not None:
|
||||
request_headers.update(self.extra_headers)
|
||||
|
||||
if self.user_agent_header:
|
||||
request_headers.setdefault("User-Agent", self.user_agent_header)
|
||||
|
||||
self.write_http_request(wsuri.resource_name, request_headers)
|
||||
|
||||
status_code, response_headers = await self.read_http_response()
|
||||
if status_code in (301, 302, 303, 307, 308):
|
||||
if "Location" not in response_headers:
|
||||
raise InvalidHeader("Location")
|
||||
raise RedirectHandshake(response_headers["Location"])
|
||||
elif status_code != 101:
|
||||
raise InvalidStatusCode(status_code, response_headers)
|
||||
|
||||
check_response(response_headers, key)
|
||||
|
||||
self.extensions = self.process_extensions(
|
||||
response_headers, available_extensions
|
||||
)
|
||||
|
||||
self.subprotocol = self.process_subprotocol(
|
||||
response_headers, available_subprotocols
|
||||
)
|
||||
|
||||
self.connection_open()
|
||||
|
||||
|
||||
class Connect:
|
||||
"""
|
||||
Connect to the WebSocket server at ``uri``.
|
||||
|
||||
Awaiting :func:`connect` yields a :class:`WebSocketClientProtocol` which
|
||||
can then be used to send and receive messages.
|
||||
|
||||
:func:`connect` can be used as a asynchronous context manager::
|
||||
|
||||
async with connect(...) as websocket:
|
||||
...
|
||||
|
||||
The connection is closed automatically when exiting the context.
|
||||
|
||||
:func:`connect` can be used as an infinite asynchronous iterator to
|
||||
reconnect automatically on errors::
|
||||
|
||||
async for websocket in connect(...):
|
||||
try:
|
||||
...
|
||||
except websockets.ConnectionClosed:
|
||||
continue
|
||||
|
||||
The connection is closed automatically after each iteration of the loop.
|
||||
|
||||
If an error occurs while establishing the connection, :func:`connect`
|
||||
retries with exponential backoff. The backoff delay starts at three
|
||||
seconds and increases up to one minute.
|
||||
|
||||
If an error occurs in the body of the loop, you can handle the exception
|
||||
and :func:`connect` will reconnect with the next iteration; or you can
|
||||
let the exception bubble up and break out of the loop. This lets you
|
||||
decide which errors trigger a reconnection and which errors are fatal.
|
||||
|
||||
Args:
|
||||
uri: URI of the WebSocket server.
|
||||
create_protocol: Factory for the :class:`asyncio.Protocol` managing
|
||||
the connection. It defaults to :class:`WebSocketClientProtocol`.
|
||||
Set it to a wrapper or a subclass to customize connection handling.
|
||||
logger: Logger for this client.
|
||||
It defaults to ``logging.getLogger("websockets.client")``.
|
||||
See the :doc:`logging guide <../../topics/logging>` for details.
|
||||
compression: The "permessage-deflate" extension is enabled by default.
|
||||
Set ``compression`` to :obj:`None` to disable it. See the
|
||||
:doc:`compression guide <../../topics/compression>` for details.
|
||||
origin: Value of the ``Origin`` header, for servers that require it.
|
||||
extensions: List of supported extensions, in order in which they
|
||||
should be negotiated and run.
|
||||
subprotocols: List of supported subprotocols, in order of decreasing
|
||||
preference.
|
||||
extra_headers: Arbitrary HTTP headers to add to the handshake request.
|
||||
user_agent_header: Value of the ``User-Agent`` request header.
|
||||
It defaults to ``"Python/x.y.z websockets/X.Y"``.
|
||||
Setting it to :obj:`None` removes the header.
|
||||
open_timeout: Timeout for opening the connection in seconds.
|
||||
:obj:`None` disables the timeout.
|
||||
|
||||
See :class:`~websockets.legacy.protocol.WebSocketCommonProtocol` for the
|
||||
documentation of ``ping_interval``, ``ping_timeout``, ``close_timeout``,
|
||||
``max_size``, ``max_queue``, ``read_limit``, and ``write_limit``.
|
||||
|
||||
Any other keyword arguments are passed the event loop's
|
||||
:meth:`~asyncio.loop.create_connection` method.
|
||||
|
||||
For example:
|
||||
|
||||
* You can set ``ssl`` to a :class:`~ssl.SSLContext` to enforce TLS
|
||||
settings. When connecting to a ``wss://`` URI, if ``ssl`` isn't
|
||||
provided, a TLS context is created
|
||||
with :func:`~ssl.create_default_context`.
|
||||
|
||||
* You can set ``host`` and ``port`` to connect to a different host and
|
||||
port from those found in ``uri``. This only changes the destination of
|
||||
the TCP connection. The host name from ``uri`` is still used in the TLS
|
||||
handshake for secure connections and in the ``Host`` header.
|
||||
|
||||
Raises:
|
||||
InvalidURI: If ``uri`` isn't a valid WebSocket URI.
|
||||
OSError: If the TCP connection fails.
|
||||
InvalidHandshake: If the opening handshake fails.
|
||||
~asyncio.TimeoutError: If the opening handshake times out.
|
||||
|
||||
"""
|
||||
|
||||
MAX_REDIRECTS_ALLOWED = int(os.environ.get("WEBSOCKETS_MAX_REDIRECTS", "10"))
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
uri: str,
|
||||
*,
|
||||
create_protocol: Callable[..., WebSocketClientProtocol] | None = None,
|
||||
logger: LoggerLike | None = None,
|
||||
compression: str | None = "deflate",
|
||||
origin: Origin | None = None,
|
||||
extensions: Sequence[ClientExtensionFactory] | None = None,
|
||||
subprotocols: Sequence[Subprotocol] | None = None,
|
||||
extra_headers: HeadersLike | None = None,
|
||||
user_agent_header: str | None = USER_AGENT,
|
||||
open_timeout: float | None = 10,
|
||||
ping_interval: float | None = 20,
|
||||
ping_timeout: float | None = 20,
|
||||
close_timeout: float | None = None,
|
||||
max_size: int | None = 2**20,
|
||||
max_queue: int | None = 2**5,
|
||||
read_limit: int = 2**16,
|
||||
write_limit: int = 2**16,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
# Backwards compatibility: close_timeout used to be called timeout.
|
||||
timeout: float | None = kwargs.pop("timeout", None)
|
||||
if timeout is None:
|
||||
timeout = 10
|
||||
else:
|
||||
warnings.warn("rename timeout to close_timeout", DeprecationWarning)
|
||||
# If both are specified, timeout is ignored.
|
||||
if close_timeout is None:
|
||||
close_timeout = timeout
|
||||
|
||||
# Backwards compatibility: create_protocol used to be called klass.
|
||||
klass: type[WebSocketClientProtocol] | None = kwargs.pop("klass", None)
|
||||
if klass is None:
|
||||
klass = WebSocketClientProtocol
|
||||
else:
|
||||
warnings.warn("rename klass to create_protocol", DeprecationWarning)
|
||||
# If both are specified, klass is ignored.
|
||||
if create_protocol is None:
|
||||
create_protocol = klass
|
||||
|
||||
# Backwards compatibility: recv() used to return None on closed connections
|
||||
legacy_recv: bool = kwargs.pop("legacy_recv", False)
|
||||
|
||||
# Backwards compatibility: the loop parameter used to be supported.
|
||||
_loop: asyncio.AbstractEventLoop | None = kwargs.pop("loop", None)
|
||||
if _loop is None:
|
||||
loop = asyncio.get_event_loop()
|
||||
else:
|
||||
loop = _loop
|
||||
warnings.warn("remove loop argument", DeprecationWarning)
|
||||
|
||||
wsuri = parse_uri(uri)
|
||||
if wsuri.secure:
|
||||
kwargs.setdefault("ssl", True)
|
||||
elif kwargs.get("ssl") is not None:
|
||||
raise ValueError(
|
||||
"connect() received a ssl argument for a ws:// URI, "
|
||||
"use a wss:// URI to enable TLS"
|
||||
)
|
||||
|
||||
if compression == "deflate":
|
||||
extensions = enable_client_permessage_deflate(extensions)
|
||||
elif compression is not None:
|
||||
raise ValueError(f"unsupported compression: {compression}")
|
||||
|
||||
if subprotocols is not None:
|
||||
validate_subprotocols(subprotocols)
|
||||
|
||||
# Help mypy and avoid this error: "type[WebSocketClientProtocol] |
|
||||
# Callable[..., WebSocketClientProtocol]" not callable [misc]
|
||||
create_protocol = cast(Callable[..., WebSocketClientProtocol], create_protocol)
|
||||
factory = functools.partial(
|
||||
create_protocol,
|
||||
logger=logger,
|
||||
origin=origin,
|
||||
extensions=extensions,
|
||||
subprotocols=subprotocols,
|
||||
extra_headers=extra_headers,
|
||||
user_agent_header=user_agent_header,
|
||||
ping_interval=ping_interval,
|
||||
ping_timeout=ping_timeout,
|
||||
close_timeout=close_timeout,
|
||||
max_size=max_size,
|
||||
max_queue=max_queue,
|
||||
read_limit=read_limit,
|
||||
write_limit=write_limit,
|
||||
host=wsuri.host,
|
||||
port=wsuri.port,
|
||||
secure=wsuri.secure,
|
||||
legacy_recv=legacy_recv,
|
||||
loop=_loop,
|
||||
)
|
||||
|
||||
if kwargs.pop("unix", False):
|
||||
path: str | None = kwargs.pop("path", None)
|
||||
create_connection = functools.partial(
|
||||
loop.create_unix_connection, factory, path, **kwargs
|
||||
)
|
||||
else:
|
||||
host: str | None
|
||||
port: int | None
|
||||
if kwargs.get("sock") is None:
|
||||
host, port = wsuri.host, wsuri.port
|
||||
else:
|
||||
# If sock is given, host and port shouldn't be specified.
|
||||
host, port = None, None
|
||||
if kwargs.get("ssl"):
|
||||
kwargs.setdefault("server_hostname", wsuri.host)
|
||||
# If host and port are given, override values from the URI.
|
||||
host = kwargs.pop("host", host)
|
||||
port = kwargs.pop("port", port)
|
||||
create_connection = functools.partial(
|
||||
loop.create_connection, factory, host, port, **kwargs
|
||||
)
|
||||
|
||||
self.open_timeout = open_timeout
|
||||
if logger is None:
|
||||
logger = logging.getLogger("websockets.client")
|
||||
self.logger = logger
|
||||
|
||||
# This is a coroutine function.
|
||||
self._create_connection = create_connection
|
||||
self._uri = uri
|
||||
self._wsuri = wsuri
|
||||
|
||||
def handle_redirect(self, uri: str) -> None:
|
||||
# Update the state of this instance to connect to a new URI.
|
||||
old_uri = self._uri
|
||||
old_wsuri = self._wsuri
|
||||
new_uri = urllib.parse.urljoin(old_uri, uri)
|
||||
new_wsuri = parse_uri(new_uri)
|
||||
|
||||
# Forbid TLS downgrade.
|
||||
if old_wsuri.secure and not new_wsuri.secure:
|
||||
raise SecurityError("redirect from WSS to WS")
|
||||
|
||||
same_origin = (
|
||||
old_wsuri.secure == new_wsuri.secure
|
||||
and old_wsuri.host == new_wsuri.host
|
||||
and old_wsuri.port == new_wsuri.port
|
||||
)
|
||||
|
||||
# Rewrite secure, host, and port for cross-origin redirects.
|
||||
# This preserves connection overrides with the host and port
|
||||
# arguments if the redirect points to the same host and port.
|
||||
if not same_origin:
|
||||
factory = self._create_connection.args[0]
|
||||
# Support TLS upgrade.
|
||||
if not old_wsuri.secure and new_wsuri.secure:
|
||||
factory.keywords["secure"] = True
|
||||
self._create_connection.keywords.setdefault("ssl", True)
|
||||
# Replace secure, host, and port arguments of the protocol factory.
|
||||
factory = functools.partial(
|
||||
factory.func,
|
||||
*factory.args,
|
||||
**dict(factory.keywords, host=new_wsuri.host, port=new_wsuri.port),
|
||||
)
|
||||
# Replace secure, host, and port arguments of create_connection.
|
||||
self._create_connection = functools.partial(
|
||||
self._create_connection.func,
|
||||
*(factory, new_wsuri.host, new_wsuri.port),
|
||||
**self._create_connection.keywords,
|
||||
)
|
||||
|
||||
# Set the new WebSocket URI. This suffices for same-origin redirects.
|
||||
self._uri = new_uri
|
||||
self._wsuri = new_wsuri
|
||||
|
||||
# async for ... in connect(...):
|
||||
|
||||
BACKOFF_INITIAL = float(os.environ.get("WEBSOCKETS_BACKOFF_INITIAL_DELAY", "5"))
|
||||
BACKOFF_MIN = float(os.environ.get("WEBSOCKETS_BACKOFF_MIN_DELAY", "3.1"))
|
||||
BACKOFF_MAX = float(os.environ.get("WEBSOCKETS_BACKOFF_MAX_DELAY", "90.0"))
|
||||
BACKOFF_FACTOR = float(os.environ.get("WEBSOCKETS_BACKOFF_FACTOR", "1.618"))
|
||||
|
||||
async def __aiter__(self) -> AsyncIterator[WebSocketClientProtocol]:
|
||||
backoff_delay = self.BACKOFF_MIN / self.BACKOFF_FACTOR
|
||||
while True:
|
||||
try:
|
||||
async with self as protocol:
|
||||
yield protocol
|
||||
except Exception:
|
||||
# Add a random initial delay between 0 and 5 seconds.
|
||||
# See 7.2.3. Recovering from Abnormal Closure in RFC 6455.
|
||||
if backoff_delay == self.BACKOFF_MIN:
|
||||
initial_delay = random.random() * self.BACKOFF_INITIAL
|
||||
self.logger.info(
|
||||
"! connect failed; reconnecting in %.1f seconds",
|
||||
initial_delay,
|
||||
exc_info=True,
|
||||
)
|
||||
await asyncio.sleep(initial_delay)
|
||||
else:
|
||||
self.logger.info(
|
||||
"! connect failed again; retrying in %d seconds",
|
||||
int(backoff_delay),
|
||||
exc_info=True,
|
||||
)
|
||||
await asyncio.sleep(int(backoff_delay))
|
||||
# Increase delay with truncated exponential backoff.
|
||||
backoff_delay = backoff_delay * self.BACKOFF_FACTOR
|
||||
backoff_delay = min(backoff_delay, self.BACKOFF_MAX)
|
||||
continue
|
||||
else:
|
||||
# Connection succeeded - reset backoff delay
|
||||
backoff_delay = self.BACKOFF_MIN
|
||||
|
||||
# async with connect(...) as ...:
|
||||
|
||||
async def __aenter__(self) -> WebSocketClientProtocol:
|
||||
return await self
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc_value: BaseException | None,
|
||||
traceback: TracebackType | None,
|
||||
) -> None:
|
||||
await self.protocol.close()
|
||||
|
||||
# ... = await connect(...)
|
||||
|
||||
def __await__(self) -> Generator[Any, None, WebSocketClientProtocol]:
|
||||
# Create a suitable iterator by calling __await__ on a coroutine.
|
||||
return self.__await_impl__().__await__()
|
||||
|
||||
async def __await_impl__(self) -> WebSocketClientProtocol:
|
||||
async with asyncio_timeout(self.open_timeout):
|
||||
for _redirects in range(self.MAX_REDIRECTS_ALLOWED):
|
||||
_transport, protocol = await self._create_connection()
|
||||
try:
|
||||
await protocol.handshake(
|
||||
self._wsuri,
|
||||
origin=protocol.origin,
|
||||
available_extensions=protocol.available_extensions,
|
||||
available_subprotocols=protocol.available_subprotocols,
|
||||
extra_headers=protocol.extra_headers,
|
||||
)
|
||||
except RedirectHandshake as exc:
|
||||
protocol.fail_connection()
|
||||
await protocol.wait_closed()
|
||||
self.handle_redirect(exc.uri)
|
||||
# Avoid leaking a connected socket when the handshake fails.
|
||||
except (Exception, asyncio.CancelledError):
|
||||
protocol.fail_connection()
|
||||
await protocol.wait_closed()
|
||||
raise
|
||||
else:
|
||||
self.protocol = protocol
|
||||
return protocol
|
||||
else:
|
||||
raise SecurityError("too many redirects")
|
||||
|
||||
# ... = yield from connect(...) - remove when dropping Python < 3.10
|
||||
|
||||
__iter__ = __await__
|
||||
|
||||
|
||||
connect = Connect
|
||||
|
||||
|
||||
def unix_connect(
|
||||
path: str | None = None,
|
||||
uri: str = "ws://localhost/",
|
||||
**kwargs: Any,
|
||||
) -> Connect:
|
||||
"""
|
||||
Similar to :func:`connect`, but for connecting to a Unix socket.
|
||||
|
||||
This function builds upon the event loop's
|
||||
:meth:`~asyncio.loop.create_unix_connection` method.
|
||||
|
||||
It is only available on Unix.
|
||||
|
||||
It's mainly useful for debugging servers listening on Unix sockets.
|
||||
|
||||
Args:
|
||||
path: File system path to the Unix socket.
|
||||
uri: URI of the WebSocket server; the host is used in the TLS
|
||||
handshake for secure connections and in the ``Host`` header.
|
||||
|
||||
"""
|
||||
return connect(uri=uri, path=path, unix=True, **kwargs)
|
@ -1,78 +0,0 @@
|
||||
import http
|
||||
|
||||
from .. import datastructures
|
||||
from ..exceptions import (
|
||||
InvalidHandshake,
|
||||
ProtocolError as WebSocketProtocolError, # noqa: F401
|
||||
)
|
||||
from ..typing import StatusLike
|
||||
|
||||
|
||||
class InvalidMessage(InvalidHandshake):
|
||||
"""
|
||||
Raised when a handshake request or response is malformed.
|
||||
|
||||
"""
|
||||
|
||||
|
||||
class InvalidStatusCode(InvalidHandshake):
|
||||
"""
|
||||
Raised when a handshake response status code is invalid.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, status_code: int, headers: datastructures.Headers) -> None:
|
||||
self.status_code = status_code
|
||||
self.headers = headers
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"server rejected WebSocket connection: HTTP {self.status_code}"
|
||||
|
||||
|
||||
class AbortHandshake(InvalidHandshake):
|
||||
"""
|
||||
Raised to abort the handshake on purpose and return an HTTP response.
|
||||
|
||||
This exception is an implementation detail.
|
||||
|
||||
The public API is
|
||||
:meth:`~websockets.legacy.server.WebSocketServerProtocol.process_request`.
|
||||
|
||||
Attributes:
|
||||
status (~http.HTTPStatus): HTTP status code.
|
||||
headers (Headers): HTTP response headers.
|
||||
body (bytes): HTTP response body.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
status: StatusLike,
|
||||
headers: datastructures.HeadersLike,
|
||||
body: bytes = b"",
|
||||
) -> None:
|
||||
# If a user passes an int instead of a HTTPStatus, fix it automatically.
|
||||
self.status = http.HTTPStatus(status)
|
||||
self.headers = datastructures.Headers(headers)
|
||||
self.body = body
|
||||
|
||||
def __str__(self) -> str:
|
||||
return (
|
||||
f"HTTP {self.status:d}, "
|
||||
f"{len(self.headers)} headers, "
|
||||
f"{len(self.body)} bytes"
|
||||
)
|
||||
|
||||
|
||||
class RedirectHandshake(InvalidHandshake):
|
||||
"""
|
||||
Raised when a handshake gets redirected.
|
||||
|
||||
This exception is an implementation detail.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, uri: str) -> None:
|
||||
self.uri = uri
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"redirect to {self.uri}"
|
@ -1,224 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import struct
|
||||
from typing import Any, Awaitable, Callable, NamedTuple, Sequence
|
||||
|
||||
from .. import extensions, frames
|
||||
from ..exceptions import PayloadTooBig, ProtocolError
|
||||
from ..frames import BytesLike
|
||||
from ..typing import Data
|
||||
|
||||
|
||||
try:
|
||||
from ..speedups import apply_mask
|
||||
except ImportError:
|
||||
from ..utils import apply_mask
|
||||
|
||||
|
||||
class Frame(NamedTuple):
|
||||
fin: bool
|
||||
opcode: frames.Opcode
|
||||
data: bytes
|
||||
rsv1: bool = False
|
||||
rsv2: bool = False
|
||||
rsv3: bool = False
|
||||
|
||||
@property
|
||||
def new_frame(self) -> frames.Frame:
|
||||
return frames.Frame(
|
||||
self.opcode,
|
||||
self.data,
|
||||
self.fin,
|
||||
self.rsv1,
|
||||
self.rsv2,
|
||||
self.rsv3,
|
||||
)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return str(self.new_frame)
|
||||
|
||||
def check(self) -> None:
|
||||
return self.new_frame.check()
|
||||
|
||||
@classmethod
|
||||
async def read(
|
||||
cls,
|
||||
reader: Callable[[int], Awaitable[bytes]],
|
||||
*,
|
||||
mask: bool,
|
||||
max_size: int | None = None,
|
||||
extensions: Sequence[extensions.Extension] | None = None,
|
||||
) -> Frame:
|
||||
"""
|
||||
Read a WebSocket frame.
|
||||
|
||||
Args:
|
||||
reader: Coroutine that reads exactly the requested number of
|
||||
bytes, unless the end of file is reached.
|
||||
mask: Whether the frame should be masked i.e. whether the read
|
||||
happens on the server side.
|
||||
max_size: Maximum payload size in bytes.
|
||||
extensions: List of extensions, applied in reverse order.
|
||||
|
||||
Raises:
|
||||
PayloadTooBig: If the frame exceeds ``max_size``.
|
||||
ProtocolError: If the frame contains incorrect values.
|
||||
|
||||
"""
|
||||
|
||||
# Read the header.
|
||||
data = await reader(2)
|
||||
head1, head2 = struct.unpack("!BB", data)
|
||||
|
||||
# While not Pythonic, this is marginally faster than calling bool().
|
||||
fin = True if head1 & 0b10000000 else False
|
||||
rsv1 = True if head1 & 0b01000000 else False
|
||||
rsv2 = True if head1 & 0b00100000 else False
|
||||
rsv3 = True if head1 & 0b00010000 else False
|
||||
|
||||
try:
|
||||
opcode = frames.Opcode(head1 & 0b00001111)
|
||||
except ValueError as exc:
|
||||
raise ProtocolError("invalid opcode") from exc
|
||||
|
||||
if (True if head2 & 0b10000000 else False) != mask:
|
||||
raise ProtocolError("incorrect masking")
|
||||
|
||||
length = head2 & 0b01111111
|
||||
if length == 126:
|
||||
data = await reader(2)
|
||||
(length,) = struct.unpack("!H", data)
|
||||
elif length == 127:
|
||||
data = await reader(8)
|
||||
(length,) = struct.unpack("!Q", data)
|
||||
if max_size is not None and length > max_size:
|
||||
raise PayloadTooBig(f"over size limit ({length} > {max_size} bytes)")
|
||||
if mask:
|
||||
mask_bits = await reader(4)
|
||||
|
||||
# Read the data.
|
||||
data = await reader(length)
|
||||
if mask:
|
||||
data = apply_mask(data, mask_bits)
|
||||
|
||||
new_frame = frames.Frame(opcode, data, fin, rsv1, rsv2, rsv3)
|
||||
|
||||
if extensions is None:
|
||||
extensions = []
|
||||
for extension in reversed(extensions):
|
||||
new_frame = extension.decode(new_frame, max_size=max_size)
|
||||
|
||||
new_frame.check()
|
||||
|
||||
return cls(
|
||||
new_frame.fin,
|
||||
new_frame.opcode,
|
||||
new_frame.data,
|
||||
new_frame.rsv1,
|
||||
new_frame.rsv2,
|
||||
new_frame.rsv3,
|
||||
)
|
||||
|
||||
def write(
|
||||
self,
|
||||
write: Callable[[bytes], Any],
|
||||
*,
|
||||
mask: bool,
|
||||
extensions: Sequence[extensions.Extension] | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Write a WebSocket frame.
|
||||
|
||||
Args:
|
||||
frame: Frame to write.
|
||||
write: Function that writes bytes.
|
||||
mask: Whether the frame should be masked i.e. whether the write
|
||||
happens on the client side.
|
||||
extensions: List of extensions, applied in order.
|
||||
|
||||
Raises:
|
||||
ProtocolError: If the frame contains incorrect values.
|
||||
|
||||
"""
|
||||
# The frame is written in a single call to write in order to prevent
|
||||
# TCP fragmentation. See #68 for details. This also makes it safe to
|
||||
# send frames concurrently from multiple coroutines.
|
||||
write(self.new_frame.serialize(mask=mask, extensions=extensions))
|
||||
|
||||
|
||||
def prepare_data(data: Data) -> tuple[int, bytes]:
|
||||
"""
|
||||
Convert a string or byte-like object to an opcode and a bytes-like object.
|
||||
|
||||
This function is designed for data frames.
|
||||
|
||||
If ``data`` is a :class:`str`, return ``OP_TEXT`` and a :class:`bytes`
|
||||
object encoding ``data`` in UTF-8.
|
||||
|
||||
If ``data`` is a bytes-like object, return ``OP_BINARY`` and a bytes-like
|
||||
object.
|
||||
|
||||
Raises:
|
||||
TypeError: If ``data`` doesn't have a supported type.
|
||||
|
||||
"""
|
||||
if isinstance(data, str):
|
||||
return frames.Opcode.TEXT, data.encode()
|
||||
elif isinstance(data, BytesLike):
|
||||
return frames.Opcode.BINARY, data
|
||||
else:
|
||||
raise TypeError("data must be str or bytes-like")
|
||||
|
||||
|
||||
def prepare_ctrl(data: Data) -> bytes:
|
||||
"""
|
||||
Convert a string or byte-like object to bytes.
|
||||
|
||||
This function is designed for ping and pong frames.
|
||||
|
||||
If ``data`` is a :class:`str`, return a :class:`bytes` object encoding
|
||||
``data`` in UTF-8.
|
||||
|
||||
If ``data`` is a bytes-like object, return a :class:`bytes` object.
|
||||
|
||||
Raises:
|
||||
TypeError: If ``data`` doesn't have a supported type.
|
||||
|
||||
"""
|
||||
if isinstance(data, str):
|
||||
return data.encode()
|
||||
elif isinstance(data, BytesLike):
|
||||
return bytes(data)
|
||||
else:
|
||||
raise TypeError("data must be str or bytes-like")
|
||||
|
||||
|
||||
# Backwards compatibility with previously documented public APIs
|
||||
encode_data = prepare_ctrl
|
||||
|
||||
# Backwards compatibility with previously documented public APIs
|
||||
from ..frames import Close # noqa: E402 F401, I001
|
||||
|
||||
|
||||
def parse_close(data: bytes) -> tuple[int, str]:
|
||||
"""
|
||||
Parse the payload from a close frame.
|
||||
|
||||
Returns:
|
||||
Close code and reason.
|
||||
|
||||
Raises:
|
||||
ProtocolError: If data is ill-formed.
|
||||
UnicodeDecodeError: If the reason isn't valid UTF-8.
|
||||
|
||||
"""
|
||||
close = Close.parse(data)
|
||||
return close.code, close.reason
|
||||
|
||||
|
||||
def serialize_close(code: int, reason: str) -> bytes:
|
||||
"""
|
||||
Serialize the payload for a close frame.
|
||||
|
||||
"""
|
||||
return Close(code, reason).serialize()
|
@ -1,158 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import binascii
|
||||
|
||||
from ..datastructures import Headers, MultipleValuesError
|
||||
from ..exceptions import InvalidHeader, InvalidHeaderValue, InvalidUpgrade
|
||||
from ..headers import parse_connection, parse_upgrade
|
||||
from ..typing import ConnectionOption, UpgradeProtocol
|
||||
from ..utils import accept_key as accept, generate_key
|
||||
|
||||
|
||||
__all__ = ["build_request", "check_request", "build_response", "check_response"]
|
||||
|
||||
|
||||
def build_request(headers: Headers) -> str:
|
||||
"""
|
||||
Build a handshake request to send to the server.
|
||||
|
||||
Update request headers passed in argument.
|
||||
|
||||
Args:
|
||||
headers: Handshake request headers.
|
||||
|
||||
Returns:
|
||||
``key`` that must be passed to :func:`check_response`.
|
||||
|
||||
"""
|
||||
key = generate_key()
|
||||
headers["Upgrade"] = "websocket"
|
||||
headers["Connection"] = "Upgrade"
|
||||
headers["Sec-WebSocket-Key"] = key
|
||||
headers["Sec-WebSocket-Version"] = "13"
|
||||
return key
|
||||
|
||||
|
||||
def check_request(headers: Headers) -> str:
|
||||
"""
|
||||
Check a handshake request received from the client.
|
||||
|
||||
This function doesn't verify that the request is an HTTP/1.1 or higher GET
|
||||
request and doesn't perform ``Host`` and ``Origin`` checks. These controls
|
||||
are usually performed earlier in the HTTP request handling code. They're
|
||||
the responsibility of the caller.
|
||||
|
||||
Args:
|
||||
headers: Handshake request headers.
|
||||
|
||||
Returns:
|
||||
``key`` that must be passed to :func:`build_response`.
|
||||
|
||||
Raises:
|
||||
InvalidHandshake: If the handshake request is invalid.
|
||||
Then, the server must return a 400 Bad Request error.
|
||||
|
||||
"""
|
||||
connection: list[ConnectionOption] = sum(
|
||||
[parse_connection(value) for value in headers.get_all("Connection")], []
|
||||
)
|
||||
|
||||
if not any(value.lower() == "upgrade" for value in connection):
|
||||
raise InvalidUpgrade("Connection", ", ".join(connection))
|
||||
|
||||
upgrade: list[UpgradeProtocol] = sum(
|
||||
[parse_upgrade(value) for value in headers.get_all("Upgrade")], []
|
||||
)
|
||||
|
||||
# For compatibility with non-strict implementations, ignore case when
|
||||
# checking the Upgrade header. The RFC always uses "websocket", except
|
||||
# in section 11.2. (IANA registration) where it uses "WebSocket".
|
||||
if not (len(upgrade) == 1 and upgrade[0].lower() == "websocket"):
|
||||
raise InvalidUpgrade("Upgrade", ", ".join(upgrade))
|
||||
|
||||
try:
|
||||
s_w_key = headers["Sec-WebSocket-Key"]
|
||||
except KeyError as exc:
|
||||
raise InvalidHeader("Sec-WebSocket-Key") from exc
|
||||
except MultipleValuesError as exc:
|
||||
raise InvalidHeader("Sec-WebSocket-Key", "multiple values") from exc
|
||||
|
||||
try:
|
||||
raw_key = base64.b64decode(s_w_key.encode(), validate=True)
|
||||
except binascii.Error as exc:
|
||||
raise InvalidHeaderValue("Sec-WebSocket-Key", s_w_key) from exc
|
||||
if len(raw_key) != 16:
|
||||
raise InvalidHeaderValue("Sec-WebSocket-Key", s_w_key)
|
||||
|
||||
try:
|
||||
s_w_version = headers["Sec-WebSocket-Version"]
|
||||
except KeyError as exc:
|
||||
raise InvalidHeader("Sec-WebSocket-Version") from exc
|
||||
except MultipleValuesError as exc:
|
||||
raise InvalidHeader("Sec-WebSocket-Version", "multiple values") from exc
|
||||
|
||||
if s_w_version != "13":
|
||||
raise InvalidHeaderValue("Sec-WebSocket-Version", s_w_version)
|
||||
|
||||
return s_w_key
|
||||
|
||||
|
||||
def build_response(headers: Headers, key: str) -> None:
|
||||
"""
|
||||
Build a handshake response to send to the client.
|
||||
|
||||
Update response headers passed in argument.
|
||||
|
||||
Args:
|
||||
headers: Handshake response headers.
|
||||
key: Returned by :func:`check_request`.
|
||||
|
||||
"""
|
||||
headers["Upgrade"] = "websocket"
|
||||
headers["Connection"] = "Upgrade"
|
||||
headers["Sec-WebSocket-Accept"] = accept(key)
|
||||
|
||||
|
||||
def check_response(headers: Headers, key: str) -> None:
|
||||
"""
|
||||
Check a handshake response received from the server.
|
||||
|
||||
This function doesn't verify that the response is an HTTP/1.1 or higher
|
||||
response with a 101 status code. These controls are the responsibility of
|
||||
the caller.
|
||||
|
||||
Args:
|
||||
headers: Handshake response headers.
|
||||
key: Returned by :func:`build_request`.
|
||||
|
||||
Raises:
|
||||
InvalidHandshake: If the handshake response is invalid.
|
||||
|
||||
"""
|
||||
connection: list[ConnectionOption] = sum(
|
||||
[parse_connection(value) for value in headers.get_all("Connection")], []
|
||||
)
|
||||
|
||||
if not any(value.lower() == "upgrade" for value in connection):
|
||||
raise InvalidUpgrade("Connection", " ".join(connection))
|
||||
|
||||
upgrade: list[UpgradeProtocol] = sum(
|
||||
[parse_upgrade(value) for value in headers.get_all("Upgrade")], []
|
||||
)
|
||||
|
||||
# For compatibility with non-strict implementations, ignore case when
|
||||
# checking the Upgrade header. The RFC always uses "websocket", except
|
||||
# in section 11.2. (IANA registration) where it uses "WebSocket".
|
||||
if not (len(upgrade) == 1 and upgrade[0].lower() == "websocket"):
|
||||
raise InvalidUpgrade("Upgrade", ", ".join(upgrade))
|
||||
|
||||
try:
|
||||
s_w_accept = headers["Sec-WebSocket-Accept"]
|
||||
except KeyError as exc:
|
||||
raise InvalidHeader("Sec-WebSocket-Accept") from exc
|
||||
except MultipleValuesError as exc:
|
||||
raise InvalidHeader("Sec-WebSocket-Accept", "multiple values") from exc
|
||||
|
||||
if s_w_accept != accept(key):
|
||||
raise InvalidHeaderValue("Sec-WebSocket-Accept", s_w_accept)
|
@ -1,201 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import re
|
||||
|
||||
from ..datastructures import Headers
|
||||
from ..exceptions import SecurityError
|
||||
|
||||
|
||||
__all__ = ["read_request", "read_response"]
|
||||
|
||||
MAX_NUM_HEADERS = int(os.environ.get("WEBSOCKETS_MAX_NUM_HEADERS", "128"))
|
||||
MAX_LINE_LENGTH = int(os.environ.get("WEBSOCKETS_MAX_LINE_LENGTH", "8192"))
|
||||
|
||||
|
||||
def d(value: bytes) -> str:
|
||||
"""
|
||||
Decode a bytestring for interpolating into an error message.
|
||||
|
||||
"""
|
||||
return value.decode(errors="backslashreplace")
|
||||
|
||||
|
||||
# See https://datatracker.ietf.org/doc/html/rfc7230#appendix-B.
|
||||
|
||||
# Regex for validating header names.
|
||||
|
||||
_token_re = re.compile(rb"[-!#$%&\'*+.^_`|~0-9a-zA-Z]+")
|
||||
|
||||
# Regex for validating header values.
|
||||
|
||||
# We don't attempt to support obsolete line folding.
|
||||
|
||||
# Include HTAB (\x09), SP (\x20), VCHAR (\x21-\x7e), obs-text (\x80-\xff).
|
||||
|
||||
# The ABNF is complicated because it attempts to express that optional
|
||||
# whitespace is ignored. We strip whitespace and don't revalidate that.
|
||||
|
||||
# See also https://www.rfc-editor.org/errata_search.php?rfc=7230&eid=4189
|
||||
|
||||
_value_re = re.compile(rb"[\x09\x20-\x7e\x80-\xff]*")
|
||||
|
||||
|
||||
async def read_request(stream: asyncio.StreamReader) -> tuple[str, Headers]:
|
||||
"""
|
||||
Read an HTTP/1.1 GET request and return ``(path, headers)``.
|
||||
|
||||
``path`` isn't URL-decoded or validated in any way.
|
||||
|
||||
``path`` and ``headers`` are expected to contain only ASCII characters.
|
||||
Other characters are represented with surrogate escapes.
|
||||
|
||||
:func:`read_request` doesn't attempt to read the request body because
|
||||
WebSocket handshake requests don't have one. If the request contains a
|
||||
body, it may be read from ``stream`` after this coroutine returns.
|
||||
|
||||
Args:
|
||||
stream: Input to read the request from.
|
||||
|
||||
Raises:
|
||||
EOFError: If the connection is closed without a full HTTP request.
|
||||
SecurityError: If the request exceeds a security limit.
|
||||
ValueError: If the request isn't well formatted.
|
||||
|
||||
"""
|
||||
# https://datatracker.ietf.org/doc/html/rfc7230#section-3.1.1
|
||||
|
||||
# Parsing is simple because fixed values are expected for method and
|
||||
# version and because path isn't checked. Since WebSocket software tends
|
||||
# to implement HTTP/1.1 strictly, there's little need for lenient parsing.
|
||||
|
||||
try:
|
||||
request_line = await read_line(stream)
|
||||
except EOFError as exc:
|
||||
raise EOFError("connection closed while reading HTTP request line") from exc
|
||||
|
||||
try:
|
||||
method, raw_path, version = request_line.split(b" ", 2)
|
||||
except ValueError: # not enough values to unpack (expected 3, got 1-2)
|
||||
raise ValueError(f"invalid HTTP request line: {d(request_line)}") from None
|
||||
|
||||
if method != b"GET":
|
||||
raise ValueError(f"unsupported HTTP method: {d(method)}")
|
||||
if version != b"HTTP/1.1":
|
||||
raise ValueError(f"unsupported HTTP version: {d(version)}")
|
||||
path = raw_path.decode("ascii", "surrogateescape")
|
||||
|
||||
headers = await read_headers(stream)
|
||||
|
||||
return path, headers
|
||||
|
||||
|
||||
async def read_response(stream: asyncio.StreamReader) -> tuple[int, str, Headers]:
|
||||
"""
|
||||
Read an HTTP/1.1 response and return ``(status_code, reason, headers)``.
|
||||
|
||||
``reason`` and ``headers`` are expected to contain only ASCII characters.
|
||||
Other characters are represented with surrogate escapes.
|
||||
|
||||
:func:`read_request` doesn't attempt to read the response body because
|
||||
WebSocket handshake responses don't have one. If the response contains a
|
||||
body, it may be read from ``stream`` after this coroutine returns.
|
||||
|
||||
Args:
|
||||
stream: Input to read the response from.
|
||||
|
||||
Raises:
|
||||
EOFError: If the connection is closed without a full HTTP response.
|
||||
SecurityError: If the response exceeds a security limit.
|
||||
ValueError: If the response isn't well formatted.
|
||||
|
||||
"""
|
||||
# https://datatracker.ietf.org/doc/html/rfc7230#section-3.1.2
|
||||
|
||||
# As in read_request, parsing is simple because a fixed value is expected
|
||||
# for version, status_code is a 3-digit number, and reason can be ignored.
|
||||
|
||||
try:
|
||||
status_line = await read_line(stream)
|
||||
except EOFError as exc:
|
||||
raise EOFError("connection closed while reading HTTP status line") from exc
|
||||
|
||||
try:
|
||||
version, raw_status_code, raw_reason = status_line.split(b" ", 2)
|
||||
except ValueError: # not enough values to unpack (expected 3, got 1-2)
|
||||
raise ValueError(f"invalid HTTP status line: {d(status_line)}") from None
|
||||
|
||||
if version != b"HTTP/1.1":
|
||||
raise ValueError(f"unsupported HTTP version: {d(version)}")
|
||||
try:
|
||||
status_code = int(raw_status_code)
|
||||
except ValueError: # invalid literal for int() with base 10
|
||||
raise ValueError(f"invalid HTTP status code: {d(raw_status_code)}") from None
|
||||
if not 100 <= status_code < 1000:
|
||||
raise ValueError(f"unsupported HTTP status code: {d(raw_status_code)}")
|
||||
if not _value_re.fullmatch(raw_reason):
|
||||
raise ValueError(f"invalid HTTP reason phrase: {d(raw_reason)}")
|
||||
reason = raw_reason.decode()
|
||||
|
||||
headers = await read_headers(stream)
|
||||
|
||||
return status_code, reason, headers
|
||||
|
||||
|
||||
async def read_headers(stream: asyncio.StreamReader) -> Headers:
|
||||
"""
|
||||
Read HTTP headers from ``stream``.
|
||||
|
||||
Non-ASCII characters are represented with surrogate escapes.
|
||||
|
||||
"""
|
||||
# https://datatracker.ietf.org/doc/html/rfc7230#section-3.2
|
||||
|
||||
# We don't attempt to support obsolete line folding.
|
||||
|
||||
headers = Headers()
|
||||
for _ in range(MAX_NUM_HEADERS + 1):
|
||||
try:
|
||||
line = await read_line(stream)
|
||||
except EOFError as exc:
|
||||
raise EOFError("connection closed while reading HTTP headers") from exc
|
||||
if line == b"":
|
||||
break
|
||||
|
||||
try:
|
||||
raw_name, raw_value = line.split(b":", 1)
|
||||
except ValueError: # not enough values to unpack (expected 2, got 1)
|
||||
raise ValueError(f"invalid HTTP header line: {d(line)}") from None
|
||||
if not _token_re.fullmatch(raw_name):
|
||||
raise ValueError(f"invalid HTTP header name: {d(raw_name)}")
|
||||
raw_value = raw_value.strip(b" \t")
|
||||
if not _value_re.fullmatch(raw_value):
|
||||
raise ValueError(f"invalid HTTP header value: {d(raw_value)}")
|
||||
|
||||
name = raw_name.decode("ascii") # guaranteed to be ASCII at this point
|
||||
value = raw_value.decode("ascii", "surrogateescape")
|
||||
headers[name] = value
|
||||
|
||||
else:
|
||||
raise SecurityError("too many HTTP headers")
|
||||
|
||||
return headers
|
||||
|
||||
|
||||
async def read_line(stream: asyncio.StreamReader) -> bytes:
|
||||
"""
|
||||
Read a single line from ``stream``.
|
||||
|
||||
CRLF is stripped from the return value.
|
||||
|
||||
"""
|
||||
# Security: this is bounded by the StreamReader's limit (default = 32 KiB).
|
||||
line = await stream.readline()
|
||||
# Security: this guarantees header values are small (hard-coded = 8 KiB)
|
||||
if len(line) > MAX_LINE_LENGTH:
|
||||
raise SecurityError("line too long")
|
||||
# Not mandatory but safe - https://datatracker.ietf.org/doc/html/rfc7230#section-3.5
|
||||
if not line.endswith(b"\r\n"):
|
||||
raise EOFError("line without CRLF")
|
||||
return line[:-2]
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@ -1,732 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import enum
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Generator, Union
|
||||
|
||||
from .exceptions import (
|
||||
ConnectionClosed,
|
||||
ConnectionClosedError,
|
||||
ConnectionClosedOK,
|
||||
InvalidState,
|
||||
PayloadTooBig,
|
||||
ProtocolError,
|
||||
)
|
||||
from .extensions import Extension
|
||||
from .frames import (
|
||||
OK_CLOSE_CODES,
|
||||
OP_BINARY,
|
||||
OP_CLOSE,
|
||||
OP_CONT,
|
||||
OP_PING,
|
||||
OP_PONG,
|
||||
OP_TEXT,
|
||||
Close,
|
||||
CloseCode,
|
||||
Frame,
|
||||
)
|
||||
from .http11 import Request, Response
|
||||
from .streams import StreamReader
|
||||
from .typing import LoggerLike, Origin, Subprotocol
|
||||
|
||||
|
||||
__all__ = [
|
||||
"Protocol",
|
||||
"Side",
|
||||
"State",
|
||||
"SEND_EOF",
|
||||
]
|
||||
|
||||
# Change to Request | Response | Frame when dropping Python < 3.10.
|
||||
Event = Union[Request, Response, Frame]
|
||||
"""Events that :meth:`~Protocol.events_received` may return."""
|
||||
|
||||
|
||||
class Side(enum.IntEnum):
|
||||
"""A WebSocket connection is either a server or a client."""
|
||||
|
||||
SERVER, CLIENT = range(2)
|
||||
|
||||
|
||||
SERVER = Side.SERVER
|
||||
CLIENT = Side.CLIENT
|
||||
|
||||
|
||||
class State(enum.IntEnum):
|
||||
"""A WebSocket connection is in one of these four states."""
|
||||
|
||||
CONNECTING, OPEN, CLOSING, CLOSED = range(4)
|
||||
|
||||
|
||||
CONNECTING = State.CONNECTING
|
||||
OPEN = State.OPEN
|
||||
CLOSING = State.CLOSING
|
||||
CLOSED = State.CLOSED
|
||||
|
||||
|
||||
SEND_EOF = b""
|
||||
"""Sentinel signaling that the TCP connection must be half-closed."""
|
||||
|
||||
|
||||
class Protocol:
|
||||
"""
|
||||
Sans-I/O implementation of a WebSocket connection.
|
||||
|
||||
Args:
|
||||
side: :attr:`~Side.CLIENT` or :attr:`~Side.SERVER`.
|
||||
state: Initial state of the WebSocket connection.
|
||||
max_size: Maximum size of incoming messages in bytes;
|
||||
:obj:`None` disables the limit.
|
||||
logger: Logger for this connection; depending on ``side``,
|
||||
defaults to ``logging.getLogger("websockets.client")``
|
||||
or ``logging.getLogger("websockets.server")``;
|
||||
see the :doc:`logging guide <../../topics/logging>` for details.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
side: Side,
|
||||
*,
|
||||
state: State = OPEN,
|
||||
max_size: int | None = 2**20,
|
||||
logger: LoggerLike | None = None,
|
||||
) -> None:
|
||||
# Unique identifier. For logs.
|
||||
self.id: uuid.UUID = uuid.uuid4()
|
||||
"""Unique identifier of the connection. Useful in logs."""
|
||||
|
||||
# Logger or LoggerAdapter for this connection.
|
||||
if logger is None:
|
||||
logger = logging.getLogger(f"websockets.{side.name.lower()}")
|
||||
self.logger: LoggerLike = logger
|
||||
"""Logger for this connection."""
|
||||
|
||||
# Track if DEBUG is enabled. Shortcut logging calls if it isn't.
|
||||
self.debug = logger.isEnabledFor(logging.DEBUG)
|
||||
|
||||
# Connection side. CLIENT or SERVER.
|
||||
self.side = side
|
||||
|
||||
# Connection state. Initially OPEN because subclasses handle CONNECTING.
|
||||
self.state = state
|
||||
|
||||
# Maximum size of incoming messages in bytes.
|
||||
self.max_size = max_size
|
||||
|
||||
# Current size of incoming message in bytes. Only set while reading a
|
||||
# fragmented message i.e. a data frames with the FIN bit not set.
|
||||
self.cur_size: int | None = None
|
||||
|
||||
# True while sending a fragmented message i.e. a data frames with the
|
||||
# FIN bit not set.
|
||||
self.expect_continuation_frame = False
|
||||
|
||||
# WebSocket protocol parameters.
|
||||
self.origin: Origin | None = None
|
||||
self.extensions: list[Extension] = []
|
||||
self.subprotocol: Subprotocol | None = None
|
||||
|
||||
# Close code and reason, set when a close frame is sent or received.
|
||||
self.close_rcvd: Close | None = None
|
||||
self.close_sent: Close | None = None
|
||||
self.close_rcvd_then_sent: bool | None = None
|
||||
|
||||
# Track if an exception happened during the handshake.
|
||||
self.handshake_exc: Exception | None = None
|
||||
"""
|
||||
Exception to raise if the opening handshake failed.
|
||||
|
||||
:obj:`None` if the opening handshake succeeded.
|
||||
|
||||
"""
|
||||
|
||||
# Track if send_eof() was called.
|
||||
self.eof_sent = False
|
||||
|
||||
# Parser state.
|
||||
self.reader = StreamReader()
|
||||
self.events: list[Event] = []
|
||||
self.writes: list[bytes] = []
|
||||
self.parser = self.parse()
|
||||
next(self.parser) # start coroutine
|
||||
self.parser_exc: Exception | None = None
|
||||
|
||||
@property
|
||||
def state(self) -> State:
|
||||
"""
|
||||
State of the WebSocket connection.
|
||||
|
||||
Defined in 4.1, 4.2, 7.1.3, and 7.1.4 of :rfc:`6455`.
|
||||
|
||||
"""
|
||||
return self._state
|
||||
|
||||
@state.setter
|
||||
def state(self, state: State) -> None:
|
||||
if self.debug:
|
||||
self.logger.debug("= connection is %s", state.name)
|
||||
self._state = state
|
||||
|
||||
@property
|
||||
def close_code(self) -> int | None:
|
||||
"""
|
||||
`WebSocket close code`_.
|
||||
|
||||
.. _WebSocket close code:
|
||||
https://datatracker.ietf.org/doc/html/rfc6455#section-7.1.5
|
||||
|
||||
:obj:`None` if the connection isn't closed yet.
|
||||
|
||||
"""
|
||||
if self.state is not CLOSED:
|
||||
return None
|
||||
elif self.close_rcvd is None:
|
||||
return CloseCode.ABNORMAL_CLOSURE
|
||||
else:
|
||||
return self.close_rcvd.code
|
||||
|
||||
@property
|
||||
def close_reason(self) -> str | None:
|
||||
"""
|
||||
`WebSocket close reason`_.
|
||||
|
||||
.. _WebSocket close reason:
|
||||
https://datatracker.ietf.org/doc/html/rfc6455#section-7.1.6
|
||||
|
||||
:obj:`None` if the connection isn't closed yet.
|
||||
|
||||
"""
|
||||
if self.state is not CLOSED:
|
||||
return None
|
||||
elif self.close_rcvd is None:
|
||||
return ""
|
||||
else:
|
||||
return self.close_rcvd.reason
|
||||
|
||||
@property
|
||||
def close_exc(self) -> ConnectionClosed:
|
||||
"""
|
||||
Exception to raise when trying to interact with a closed connection.
|
||||
|
||||
Don't raise this exception while the connection :attr:`state`
|
||||
is :attr:`~websockets.protocol.State.CLOSING`; wait until
|
||||
it's :attr:`~websockets.protocol.State.CLOSED`.
|
||||
|
||||
Indeed, the exception includes the close code and reason, which are
|
||||
known only once the connection is closed.
|
||||
|
||||
Raises:
|
||||
AssertionError: If the connection isn't closed yet.
|
||||
|
||||
"""
|
||||
assert self.state is CLOSED, "connection isn't closed yet"
|
||||
exc_type: type[ConnectionClosed]
|
||||
if (
|
||||
self.close_rcvd is not None
|
||||
and self.close_sent is not None
|
||||
and self.close_rcvd.code in OK_CLOSE_CODES
|
||||
and self.close_sent.code in OK_CLOSE_CODES
|
||||
):
|
||||
exc_type = ConnectionClosedOK
|
||||
else:
|
||||
exc_type = ConnectionClosedError
|
||||
exc: ConnectionClosed = exc_type(
|
||||
self.close_rcvd,
|
||||
self.close_sent,
|
||||
self.close_rcvd_then_sent,
|
||||
)
|
||||
# Chain to the exception raised in the parser, if any.
|
||||
exc.__cause__ = self.parser_exc
|
||||
return exc
|
||||
|
||||
# Public methods for receiving data.
|
||||
|
||||
def receive_data(self, data: bytes) -> None:
|
||||
"""
|
||||
Receive data from the network.
|
||||
|
||||
After calling this method:
|
||||
|
||||
- You must call :meth:`data_to_send` and send this data to the network.
|
||||
- You should call :meth:`events_received` and process resulting events.
|
||||
|
||||
Raises:
|
||||
EOFError: If :meth:`receive_eof` was called earlier.
|
||||
|
||||
"""
|
||||
self.reader.feed_data(data)
|
||||
next(self.parser)
|
||||
|
||||
def receive_eof(self) -> None:
|
||||
"""
|
||||
Receive the end of the data stream from the network.
|
||||
|
||||
After calling this method:
|
||||
|
||||
- You must call :meth:`data_to_send` and send this data to the network;
|
||||
it will return ``[b""]``, signaling the end of the stream, or ``[]``.
|
||||
- You aren't expected to call :meth:`events_received`; it won't return
|
||||
any new events.
|
||||
|
||||
:meth:`receive_eof` is idempotent.
|
||||
|
||||
"""
|
||||
if self.reader.eof:
|
||||
return
|
||||
self.reader.feed_eof()
|
||||
next(self.parser)
|
||||
|
||||
# Public methods for sending events.
|
||||
|
||||
def send_continuation(self, data: bytes, fin: bool) -> None:
|
||||
"""
|
||||
Send a `Continuation frame`_.
|
||||
|
||||
.. _Continuation frame:
|
||||
https://datatracker.ietf.org/doc/html/rfc6455#section-5.6
|
||||
|
||||
Parameters:
|
||||
data: payload containing the same kind of data
|
||||
as the initial frame.
|
||||
fin: FIN bit; set it to :obj:`True` if this is the last frame
|
||||
of a fragmented message and to :obj:`False` otherwise.
|
||||
|
||||
Raises:
|
||||
ProtocolError: If a fragmented message isn't in progress.
|
||||
|
||||
"""
|
||||
if not self.expect_continuation_frame:
|
||||
raise ProtocolError("unexpected continuation frame")
|
||||
if self._state is not OPEN:
|
||||
raise InvalidState(f"connection is {self.state.name.lower()}")
|
||||
self.expect_continuation_frame = not fin
|
||||
self.send_frame(Frame(OP_CONT, data, fin))
|
||||
|
||||
def send_text(self, data: bytes, fin: bool = True) -> None:
|
||||
"""
|
||||
Send a `Text frame`_.
|
||||
|
||||
.. _Text frame:
|
||||
https://datatracker.ietf.org/doc/html/rfc6455#section-5.6
|
||||
|
||||
Parameters:
|
||||
data: payload containing text encoded with UTF-8.
|
||||
fin: FIN bit; set it to :obj:`False` if this is the first frame of
|
||||
a fragmented message.
|
||||
|
||||
Raises:
|
||||
ProtocolError: If a fragmented message is in progress.
|
||||
|
||||
"""
|
||||
if self.expect_continuation_frame:
|
||||
raise ProtocolError("expected a continuation frame")
|
||||
if self._state is not OPEN:
|
||||
raise InvalidState(f"connection is {self.state.name.lower()}")
|
||||
self.expect_continuation_frame = not fin
|
||||
self.send_frame(Frame(OP_TEXT, data, fin))
|
||||
|
||||
def send_binary(self, data: bytes, fin: bool = True) -> None:
|
||||
"""
|
||||
Send a `Binary frame`_.
|
||||
|
||||
.. _Binary frame:
|
||||
https://datatracker.ietf.org/doc/html/rfc6455#section-5.6
|
||||
|
||||
Parameters:
|
||||
data: payload containing arbitrary binary data.
|
||||
fin: FIN bit; set it to :obj:`False` if this is the first frame of
|
||||
a fragmented message.
|
||||
|
||||
Raises:
|
||||
ProtocolError: If a fragmented message is in progress.
|
||||
|
||||
"""
|
||||
if self.expect_continuation_frame:
|
||||
raise ProtocolError("expected a continuation frame")
|
||||
if self._state is not OPEN:
|
||||
raise InvalidState(f"connection is {self.state.name.lower()}")
|
||||
self.expect_continuation_frame = not fin
|
||||
self.send_frame(Frame(OP_BINARY, data, fin))
|
||||
|
||||
def send_close(self, code: int | None = None, reason: str = "") -> None:
|
||||
"""
|
||||
Send a `Close frame`_.
|
||||
|
||||
.. _Close frame:
|
||||
https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.1
|
||||
|
||||
Parameters:
|
||||
code: close code.
|
||||
reason: close reason.
|
||||
|
||||
Raises:
|
||||
ProtocolError: If the code isn't valid or if a reason is provided
|
||||
without a code.
|
||||
|
||||
"""
|
||||
# While RFC 6455 doesn't rule out sending more than one close Frame,
|
||||
# websockets is conservative in what it sends and doesn't allow that.
|
||||
if self._state is not OPEN:
|
||||
raise InvalidState(f"connection is {self.state.name.lower()}")
|
||||
if code is None:
|
||||
if reason != "":
|
||||
raise ProtocolError("cannot send a reason without a code")
|
||||
close = Close(CloseCode.NO_STATUS_RCVD, "")
|
||||
data = b""
|
||||
else:
|
||||
close = Close(code, reason)
|
||||
data = close.serialize()
|
||||
# 7.1.3. The WebSocket Closing Handshake is Started
|
||||
self.send_frame(Frame(OP_CLOSE, data))
|
||||
# Since the state is OPEN, no close frame was received yet.
|
||||
# As a consequence, self.close_rcvd_then_sent remains None.
|
||||
assert self.close_rcvd is None
|
||||
self.close_sent = close
|
||||
self.state = CLOSING
|
||||
|
||||
def send_ping(self, data: bytes) -> None:
|
||||
"""
|
||||
Send a `Ping frame`_.
|
||||
|
||||
.. _Ping frame:
|
||||
https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.2
|
||||
|
||||
Parameters:
|
||||
data: payload containing arbitrary binary data.
|
||||
|
||||
"""
|
||||
# RFC 6455 allows control frames after starting the closing handshake.
|
||||
if self._state is not OPEN and self._state is not CLOSING:
|
||||
raise InvalidState(f"connection is {self.state.name.lower()}")
|
||||
self.send_frame(Frame(OP_PING, data))
|
||||
|
||||
def send_pong(self, data: bytes) -> None:
|
||||
"""
|
||||
Send a `Pong frame`_.
|
||||
|
||||
.. _Pong frame:
|
||||
https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.3
|
||||
|
||||
Parameters:
|
||||
data: payload containing arbitrary binary data.
|
||||
|
||||
"""
|
||||
# RFC 6455 allows control frames after starting the closing handshake.
|
||||
if self._state is not OPEN and self._state is not CLOSING:
|
||||
raise InvalidState(f"connection is {self.state.name.lower()}")
|
||||
self.send_frame(Frame(OP_PONG, data))
|
||||
|
||||
def fail(self, code: int, reason: str = "") -> None:
|
||||
"""
|
||||
`Fail the WebSocket connection`_.
|
||||
|
||||
.. _Fail the WebSocket connection:
|
||||
https://datatracker.ietf.org/doc/html/rfc6455#section-7.1.7
|
||||
|
||||
Parameters:
|
||||
code: close code
|
||||
reason: close reason
|
||||
|
||||
Raises:
|
||||
ProtocolError: If the code isn't valid.
|
||||
"""
|
||||
# 7.1.7. Fail the WebSocket Connection
|
||||
|
||||
# Send a close frame when the state is OPEN (a close frame was already
|
||||
# sent if it's CLOSING), except when failing the connection because
|
||||
# of an error reading from or writing to the network.
|
||||
if self.state is OPEN:
|
||||
if code != CloseCode.ABNORMAL_CLOSURE:
|
||||
close = Close(code, reason)
|
||||
data = close.serialize()
|
||||
self.send_frame(Frame(OP_CLOSE, data))
|
||||
self.close_sent = close
|
||||
# If recv_messages() raised an exception upon receiving a close
|
||||
# frame but before echoing it, then close_rcvd is not None even
|
||||
# though the state is OPEN. This happens when the connection is
|
||||
# closed while receiving a fragmented message.
|
||||
if self.close_rcvd is not None:
|
||||
self.close_rcvd_then_sent = True
|
||||
self.state = CLOSING
|
||||
|
||||
# When failing the connection, a server closes the TCP connection
|
||||
# without waiting for the client to complete the handshake, while a
|
||||
# client waits for the server to close the TCP connection, possibly
|
||||
# after sending a close frame that the client will ignore.
|
||||
if self.side is SERVER and not self.eof_sent:
|
||||
self.send_eof()
|
||||
|
||||
# 7.1.7. Fail the WebSocket Connection "An endpoint MUST NOT continue
|
||||
# to attempt to process data(including a responding Close frame) from
|
||||
# the remote endpoint after being instructed to _Fail the WebSocket
|
||||
# Connection_."
|
||||
self.parser = self.discard()
|
||||
next(self.parser) # start coroutine
|
||||
|
||||
# Public method for getting incoming events after receiving data.
|
||||
|
||||
def events_received(self) -> list[Event]:
|
||||
"""
|
||||
Fetch events generated from data received from the network.
|
||||
|
||||
Call this method immediately after any of the ``receive_*()`` methods.
|
||||
|
||||
Process resulting events, likely by passing them to the application.
|
||||
|
||||
Returns:
|
||||
Events read from the connection.
|
||||
"""
|
||||
events, self.events = self.events, []
|
||||
return events
|
||||
|
||||
# Public method for getting outgoing data after receiving data or sending events.
|
||||
|
||||
def data_to_send(self) -> list[bytes]:
|
||||
"""
|
||||
Obtain data to send to the network.
|
||||
|
||||
Call this method immediately after any of the ``receive_*()``,
|
||||
``send_*()``, or :meth:`fail` methods.
|
||||
|
||||
Write resulting data to the connection.
|
||||
|
||||
The empty bytestring :data:`~websockets.protocol.SEND_EOF` signals
|
||||
the end of the data stream. When you receive it, half-close the TCP
|
||||
connection.
|
||||
|
||||
Returns:
|
||||
Data to write to the connection.
|
||||
|
||||
"""
|
||||
writes, self.writes = self.writes, []
|
||||
return writes
|
||||
|
||||
def close_expected(self) -> bool:
|
||||
"""
|
||||
Tell if the TCP connection is expected to close soon.
|
||||
|
||||
Call this method immediately after any of the ``receive_*()``,
|
||||
``send_close()``, or :meth:`fail` methods.
|
||||
|
||||
If it returns :obj:`True`, schedule closing the TCP connection after a
|
||||
short timeout if the other side hasn't already closed it.
|
||||
|
||||
Returns:
|
||||
Whether the TCP connection is expected to close soon.
|
||||
|
||||
"""
|
||||
# We expect a TCP close if and only if we sent a close frame:
|
||||
# * Normal closure: once we send a close frame, we expect a TCP close:
|
||||
# server waits for client to complete the TCP closing handshake;
|
||||
# client waits for server to initiate the TCP closing handshake.
|
||||
# * Abnormal closure: we always send a close frame and the same logic
|
||||
# applies, except on EOFError where we don't send a close frame
|
||||
# because we already received the TCP close, so we don't expect it.
|
||||
# We already got a TCP Close if and only if the state is CLOSED.
|
||||
return self.state is CLOSING or self.handshake_exc is not None
|
||||
|
||||
# Private methods for receiving data.
|
||||
|
||||
def parse(self) -> Generator[None, None, None]:
|
||||
"""
|
||||
Parse incoming data into frames.
|
||||
|
||||
:meth:`receive_data` and :meth:`receive_eof` run this generator
|
||||
coroutine until it needs more data or reaches EOF.
|
||||
|
||||
:meth:`parse` never raises an exception. Instead, it sets the
|
||||
:attr:`parser_exc` and yields control.
|
||||
|
||||
"""
|
||||
try:
|
||||
while True:
|
||||
if (yield from self.reader.at_eof()):
|
||||
if self.debug:
|
||||
self.logger.debug("< EOF")
|
||||
# If the WebSocket connection is closed cleanly, with a
|
||||
# closing handhshake, recv_frame() substitutes parse()
|
||||
# with discard(). This branch is reached only when the
|
||||
# connection isn't closed cleanly.
|
||||
raise EOFError("unexpected end of stream")
|
||||
|
||||
if self.max_size is None:
|
||||
max_size = None
|
||||
elif self.cur_size is None:
|
||||
max_size = self.max_size
|
||||
else:
|
||||
max_size = self.max_size - self.cur_size
|
||||
|
||||
# During a normal closure, execution ends here on the next
|
||||
# iteration of the loop after receiving a close frame. At
|
||||
# this point, recv_frame() replaced parse() by discard().
|
||||
frame = yield from Frame.parse(
|
||||
self.reader.read_exact,
|
||||
mask=self.side is SERVER,
|
||||
max_size=max_size,
|
||||
extensions=self.extensions,
|
||||
)
|
||||
|
||||
if self.debug:
|
||||
self.logger.debug("< %s", frame)
|
||||
|
||||
self.recv_frame(frame)
|
||||
|
||||
except ProtocolError as exc:
|
||||
self.fail(CloseCode.PROTOCOL_ERROR, str(exc))
|
||||
self.parser_exc = exc
|
||||
|
||||
except EOFError as exc:
|
||||
self.fail(CloseCode.ABNORMAL_CLOSURE, str(exc))
|
||||
self.parser_exc = exc
|
||||
|
||||
except UnicodeDecodeError as exc:
|
||||
self.fail(CloseCode.INVALID_DATA, f"{exc.reason} at position {exc.start}")
|
||||
self.parser_exc = exc
|
||||
|
||||
except PayloadTooBig as exc:
|
||||
self.fail(CloseCode.MESSAGE_TOO_BIG, str(exc))
|
||||
self.parser_exc = exc
|
||||
|
||||
except Exception as exc:
|
||||
self.logger.error("parser failed", exc_info=True)
|
||||
# Don't include exception details, which may be security-sensitive.
|
||||
self.fail(CloseCode.INTERNAL_ERROR)
|
||||
self.parser_exc = exc
|
||||
|
||||
# During an abnormal closure, execution ends here after catching an
|
||||
# exception. At this point, fail() replaced parse() by discard().
|
||||
yield
|
||||
raise AssertionError("parse() shouldn't step after error")
|
||||
|
||||
def discard(self) -> Generator[None, None, None]:
|
||||
"""
|
||||
Discard incoming data.
|
||||
|
||||
This coroutine replaces :meth:`parse`:
|
||||
|
||||
- after receiving a close frame, during a normal closure (1.4);
|
||||
- after sending a close frame, during an abnormal closure (7.1.7).
|
||||
|
||||
"""
|
||||
# After the opening handshake completes, the server closes the TCP
|
||||
# connection in the same circumstances where discard() replaces parse().
|
||||
# The client closes it when it receives EOF from the server or times
|
||||
# out. (The latter case cannot be handled in this Sans-I/O layer.)
|
||||
assert (self.state == CONNECTING or self.side is SERVER) == (self.eof_sent)
|
||||
while not (yield from self.reader.at_eof()):
|
||||
self.reader.discard()
|
||||
if self.debug:
|
||||
self.logger.debug("< EOF")
|
||||
# A server closes the TCP connection immediately, while a client
|
||||
# waits for the server to close the TCP connection.
|
||||
if self.state != CONNECTING and self.side is CLIENT:
|
||||
self.send_eof()
|
||||
self.state = CLOSED
|
||||
# If discard() completes normally, execution ends here.
|
||||
yield
|
||||
# Once the reader reaches EOF, its feed_data/eof() methods raise an
|
||||
# error, so our receive_data/eof() methods don't step the generator.
|
||||
raise AssertionError("discard() shouldn't step after EOF")
|
||||
|
||||
def recv_frame(self, frame: Frame) -> None:
|
||||
"""
|
||||
Process an incoming frame.
|
||||
|
||||
"""
|
||||
if frame.opcode is OP_TEXT or frame.opcode is OP_BINARY:
|
||||
if self.cur_size is not None:
|
||||
raise ProtocolError("expected a continuation frame")
|
||||
if frame.fin:
|
||||
self.cur_size = None
|
||||
else:
|
||||
self.cur_size = len(frame.data)
|
||||
|
||||
elif frame.opcode is OP_CONT:
|
||||
if self.cur_size is None:
|
||||
raise ProtocolError("unexpected continuation frame")
|
||||
if frame.fin:
|
||||
self.cur_size = None
|
||||
else:
|
||||
self.cur_size += len(frame.data)
|
||||
|
||||
elif frame.opcode is OP_PING:
|
||||
# 5.5.2. Ping: "Upon receipt of a Ping frame, an endpoint MUST
|
||||
# send a Pong frame in response"
|
||||
pong_frame = Frame(OP_PONG, frame.data)
|
||||
self.send_frame(pong_frame)
|
||||
|
||||
elif frame.opcode is OP_PONG:
|
||||
# 5.5.3 Pong: "A response to an unsolicited Pong frame is not
|
||||
# expected."
|
||||
pass
|
||||
|
||||
elif frame.opcode is OP_CLOSE:
|
||||
# 7.1.5. The WebSocket Connection Close Code
|
||||
# 7.1.6. The WebSocket Connection Close Reason
|
||||
self.close_rcvd = Close.parse(frame.data)
|
||||
if self.state is CLOSING:
|
||||
assert self.close_sent is not None
|
||||
self.close_rcvd_then_sent = False
|
||||
|
||||
if self.cur_size is not None:
|
||||
raise ProtocolError("incomplete fragmented message")
|
||||
|
||||
# 5.5.1 Close: "If an endpoint receives a Close frame and did
|
||||
# not previously send a Close frame, the endpoint MUST send a
|
||||
# Close frame in response. (When sending a Close frame in
|
||||
# response, the endpoint typically echos the status code it
|
||||
# received.)"
|
||||
|
||||
if self.state is OPEN:
|
||||
# Echo the original data instead of re-serializing it with
|
||||
# Close.serialize() because that fails when the close frame
|
||||
# is empty and Close.parse() synthesizes a 1005 close code.
|
||||
# The rest is identical to send_close().
|
||||
self.send_frame(Frame(OP_CLOSE, frame.data))
|
||||
self.close_sent = self.close_rcvd
|
||||
self.close_rcvd_then_sent = True
|
||||
self.state = CLOSING
|
||||
|
||||
# 7.1.2. Start the WebSocket Closing Handshake: "Once an
|
||||
# endpoint has both sent and received a Close control frame,
|
||||
# that endpoint SHOULD _Close the WebSocket Connection_"
|
||||
|
||||
# A server closes the TCP connection immediately, while a client
|
||||
# waits for the server to close the TCP connection.
|
||||
if self.side is SERVER:
|
||||
self.send_eof()
|
||||
|
||||
# 1.4. Closing Handshake: "after receiving a control frame
|
||||
# indicating the connection should be closed, a peer discards
|
||||
# any further data received."
|
||||
# RFC 6455 allows reading Ping and Pong frames after a Close frame.
|
||||
# However, that doesn't seem useful; websockets doesn't support it.
|
||||
self.parser = self.discard()
|
||||
next(self.parser) # start coroutine
|
||||
|
||||
else:
|
||||
# This can't happen because Frame.parse() validates opcodes.
|
||||
raise AssertionError(f"unexpected opcode: {frame.opcode:02x}")
|
||||
|
||||
self.events.append(frame)
|
||||
|
||||
# Private methods for sending events.
|
||||
|
||||
def send_frame(self, frame: Frame) -> None:
|
||||
if self.debug:
|
||||
self.logger.debug("> %s", frame)
|
||||
self.writes.append(
|
||||
frame.serialize(
|
||||
mask=self.side is CLIENT,
|
||||
extensions=self.extensions,
|
||||
)
|
||||
)
|
||||
|
||||
def send_eof(self) -> None:
|
||||
assert not self.eof_sent
|
||||
self.eof_sent = True
|
||||
if self.debug:
|
||||
self.logger.debug("> EOF")
|
||||
self.writes.append(SEND_EOF)
|
@ -1,587 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import binascii
|
||||
import email.utils
|
||||
import http
|
||||
import warnings
|
||||
from typing import Any, Callable, Generator, Sequence, cast
|
||||
|
||||
from .datastructures import Headers, MultipleValuesError
|
||||
from .exceptions import (
|
||||
InvalidHandshake,
|
||||
InvalidHeader,
|
||||
InvalidHeaderValue,
|
||||
InvalidOrigin,
|
||||
InvalidStatus,
|
||||
InvalidUpgrade,
|
||||
NegotiationError,
|
||||
)
|
||||
from .extensions import Extension, ServerExtensionFactory
|
||||
from .headers import (
|
||||
build_extension,
|
||||
parse_connection,
|
||||
parse_extension,
|
||||
parse_subprotocol,
|
||||
parse_upgrade,
|
||||
)
|
||||
from .http11 import Request, Response
|
||||
from .protocol import CONNECTING, OPEN, SERVER, Protocol, State
|
||||
from .typing import (
|
||||
ConnectionOption,
|
||||
ExtensionHeader,
|
||||
LoggerLike,
|
||||
Origin,
|
||||
StatusLike,
|
||||
Subprotocol,
|
||||
UpgradeProtocol,
|
||||
)
|
||||
from .utils import accept_key
|
||||
|
||||
|
||||
# See #940 for why lazy_import isn't used here for backwards compatibility.
|
||||
# See #1400 for why listing compatibility imports in __all__ helps PyCharm.
|
||||
from .legacy.server import * # isort:skip # noqa: I001
|
||||
from .legacy.server import __all__ as legacy__all__
|
||||
|
||||
|
||||
__all__ = ["ServerProtocol"] + legacy__all__
|
||||
|
||||
|
||||
class ServerProtocol(Protocol):
|
||||
"""
|
||||
Sans-I/O implementation of a WebSocket server connection.
|
||||
|
||||
Args:
|
||||
origins: Acceptable values of the ``Origin`` header; include
|
||||
:obj:`None` in the list if the lack of an origin is acceptable.
|
||||
This is useful for defending against Cross-Site WebSocket
|
||||
Hijacking attacks.
|
||||
extensions: List of supported extensions, in order in which they
|
||||
should be tried.
|
||||
subprotocols: List of supported subprotocols, in order of decreasing
|
||||
preference.
|
||||
select_subprotocol: Callback for selecting a subprotocol among
|
||||
those supported by the client and the server. It has the same
|
||||
signature as the :meth:`select_subprotocol` method, including a
|
||||
:class:`ServerProtocol` instance as first argument.
|
||||
state: Initial state of the WebSocket connection.
|
||||
max_size: Maximum size of incoming messages in bytes;
|
||||
:obj:`None` disables the limit.
|
||||
logger: Logger for this connection;
|
||||
defaults to ``logging.getLogger("websockets.server")``;
|
||||
see the :doc:`logging guide <../../topics/logging>` for details.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
origins: Sequence[Origin | None] | None = None,
|
||||
extensions: Sequence[ServerExtensionFactory] | None = None,
|
||||
subprotocols: Sequence[Subprotocol] | None = None,
|
||||
select_subprotocol: (
|
||||
Callable[
|
||||
[ServerProtocol, Sequence[Subprotocol]],
|
||||
Subprotocol | None,
|
||||
]
|
||||
| None
|
||||
) = None,
|
||||
state: State = CONNECTING,
|
||||
max_size: int | None = 2**20,
|
||||
logger: LoggerLike | None = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
side=SERVER,
|
||||
state=state,
|
||||
max_size=max_size,
|
||||
logger=logger,
|
||||
)
|
||||
self.origins = origins
|
||||
self.available_extensions = extensions
|
||||
self.available_subprotocols = subprotocols
|
||||
if select_subprotocol is not None:
|
||||
# Bind select_subprotocol then shadow self.select_subprotocol.
|
||||
# Use setattr to work around https://github.com/python/mypy/issues/2427.
|
||||
setattr(
|
||||
self,
|
||||
"select_subprotocol",
|
||||
select_subprotocol.__get__(self, self.__class__),
|
||||
)
|
||||
|
||||
def accept(self, request: Request) -> Response:
|
||||
"""
|
||||
Create a handshake response to accept the connection.
|
||||
|
||||
If the handshake request is valid and the handshake successful,
|
||||
:meth:`accept` returns an HTTP response with status code 101.
|
||||
|
||||
Else, it returns an HTTP response with another status code. This rejects
|
||||
the connection, like :meth:`reject` would.
|
||||
|
||||
You must send the handshake response with :meth:`send_response`.
|
||||
|
||||
You may modify the response before sending it, typically by adding HTTP
|
||||
headers.
|
||||
|
||||
Args:
|
||||
request: WebSocket handshake request received from the client.
|
||||
|
||||
Returns:
|
||||
WebSocket handshake response or HTTP response to send to the client.
|
||||
|
||||
"""
|
||||
try:
|
||||
(
|
||||
accept_header,
|
||||
extensions_header,
|
||||
protocol_header,
|
||||
) = self.process_request(request)
|
||||
except InvalidOrigin as exc:
|
||||
request._exception = exc
|
||||
self.handshake_exc = exc
|
||||
if self.debug:
|
||||
self.logger.debug("! invalid origin", exc_info=True)
|
||||
return self.reject(
|
||||
http.HTTPStatus.FORBIDDEN,
|
||||
f"Failed to open a WebSocket connection: {exc}.\n",
|
||||
)
|
||||
except InvalidUpgrade as exc:
|
||||
request._exception = exc
|
||||
self.handshake_exc = exc
|
||||
if self.debug:
|
||||
self.logger.debug("! invalid upgrade", exc_info=True)
|
||||
response = self.reject(
|
||||
http.HTTPStatus.UPGRADE_REQUIRED,
|
||||
(
|
||||
f"Failed to open a WebSocket connection: {exc}.\n"
|
||||
f"\n"
|
||||
f"You cannot access a WebSocket server directly "
|
||||
f"with a browser. You need a WebSocket client.\n"
|
||||
),
|
||||
)
|
||||
response.headers["Upgrade"] = "websocket"
|
||||
return response
|
||||
except InvalidHandshake as exc:
|
||||
request._exception = exc
|
||||
self.handshake_exc = exc
|
||||
if self.debug:
|
||||
self.logger.debug("! invalid handshake", exc_info=True)
|
||||
exc_chain = cast(BaseException, exc)
|
||||
exc_str = f"{exc_chain}"
|
||||
while exc_chain.__cause__ is not None:
|
||||
exc_chain = exc_chain.__cause__
|
||||
exc_str += f"; {exc_chain}"
|
||||
return self.reject(
|
||||
http.HTTPStatus.BAD_REQUEST,
|
||||
f"Failed to open a WebSocket connection: {exc_str}.\n",
|
||||
)
|
||||
except Exception as exc:
|
||||
# Handle exceptions raised by user-provided select_subprotocol and
|
||||
# unexpected errors.
|
||||
request._exception = exc
|
||||
self.handshake_exc = exc
|
||||
self.logger.error("opening handshake failed", exc_info=True)
|
||||
return self.reject(
|
||||
http.HTTPStatus.INTERNAL_SERVER_ERROR,
|
||||
(
|
||||
"Failed to open a WebSocket connection.\n"
|
||||
"See server log for more information.\n"
|
||||
),
|
||||
)
|
||||
|
||||
headers = Headers()
|
||||
|
||||
headers["Date"] = email.utils.formatdate(usegmt=True)
|
||||
|
||||
headers["Upgrade"] = "websocket"
|
||||
headers["Connection"] = "Upgrade"
|
||||
headers["Sec-WebSocket-Accept"] = accept_header
|
||||
|
||||
if extensions_header is not None:
|
||||
headers["Sec-WebSocket-Extensions"] = extensions_header
|
||||
|
||||
if protocol_header is not None:
|
||||
headers["Sec-WebSocket-Protocol"] = protocol_header
|
||||
|
||||
return Response(101, "Switching Protocols", headers)
|
||||
|
||||
def process_request(
|
||||
self,
|
||||
request: Request,
|
||||
) -> tuple[str, str | None, str | None]:
|
||||
"""
|
||||
Check a handshake request and negotiate extensions and subprotocol.
|
||||
|
||||
This function doesn't verify that the request is an HTTP/1.1 or higher
|
||||
GET request and doesn't check the ``Host`` header. These controls are
|
||||
usually performed earlier in the HTTP request handling code. They're
|
||||
the responsibility of the caller.
|
||||
|
||||
Args:
|
||||
request: WebSocket handshake request received from the client.
|
||||
|
||||
Returns:
|
||||
``Sec-WebSocket-Accept``, ``Sec-WebSocket-Extensions``, and
|
||||
``Sec-WebSocket-Protocol`` headers for the handshake response.
|
||||
|
||||
Raises:
|
||||
InvalidHandshake: If the handshake request is invalid;
|
||||
then the server must return 400 Bad Request error.
|
||||
|
||||
"""
|
||||
headers = request.headers
|
||||
|
||||
connection: list[ConnectionOption] = sum(
|
||||
[parse_connection(value) for value in headers.get_all("Connection")], []
|
||||
)
|
||||
|
||||
if not any(value.lower() == "upgrade" for value in connection):
|
||||
raise InvalidUpgrade(
|
||||
"Connection", ", ".join(connection) if connection else None
|
||||
)
|
||||
|
||||
upgrade: list[UpgradeProtocol] = sum(
|
||||
[parse_upgrade(value) for value in headers.get_all("Upgrade")], []
|
||||
)
|
||||
|
||||
# For compatibility with non-strict implementations, ignore case when
|
||||
# checking the Upgrade header. The RFC always uses "websocket", except
|
||||
# in section 11.2. (IANA registration) where it uses "WebSocket".
|
||||
if not (len(upgrade) == 1 and upgrade[0].lower() == "websocket"):
|
||||
raise InvalidUpgrade("Upgrade", ", ".join(upgrade) if upgrade else None)
|
||||
|
||||
try:
|
||||
key = headers["Sec-WebSocket-Key"]
|
||||
except KeyError as exc:
|
||||
raise InvalidHeader("Sec-WebSocket-Key") from exc
|
||||
except MultipleValuesError as exc:
|
||||
raise InvalidHeader("Sec-WebSocket-Key", "multiple values") from exc
|
||||
|
||||
try:
|
||||
raw_key = base64.b64decode(key.encode(), validate=True)
|
||||
except binascii.Error as exc:
|
||||
raise InvalidHeaderValue("Sec-WebSocket-Key", key) from exc
|
||||
if len(raw_key) != 16:
|
||||
raise InvalidHeaderValue("Sec-WebSocket-Key", key)
|
||||
|
||||
try:
|
||||
version = headers["Sec-WebSocket-Version"]
|
||||
except KeyError as exc:
|
||||
raise InvalidHeader("Sec-WebSocket-Version") from exc
|
||||
except MultipleValuesError as exc:
|
||||
raise InvalidHeader("Sec-WebSocket-Version", "multiple values") from exc
|
||||
|
||||
if version != "13":
|
||||
raise InvalidHeaderValue("Sec-WebSocket-Version", version)
|
||||
|
||||
accept_header = accept_key(key)
|
||||
|
||||
self.origin = self.process_origin(headers)
|
||||
|
||||
extensions_header, self.extensions = self.process_extensions(headers)
|
||||
|
||||
protocol_header = self.subprotocol = self.process_subprotocol(headers)
|
||||
|
||||
return (
|
||||
accept_header,
|
||||
extensions_header,
|
||||
protocol_header,
|
||||
)
|
||||
|
||||
def process_origin(self, headers: Headers) -> Origin | None:
|
||||
"""
|
||||
Handle the Origin HTTP request header.
|
||||
|
||||
Args:
|
||||
headers: WebSocket handshake request headers.
|
||||
|
||||
Returns:
|
||||
origin, if it is acceptable.
|
||||
|
||||
Raises:
|
||||
InvalidHandshake: If the Origin header is invalid.
|
||||
InvalidOrigin: If the origin isn't acceptable.
|
||||
|
||||
"""
|
||||
# "The user agent MUST NOT include more than one Origin header field"
|
||||
# per https://datatracker.ietf.org/doc/html/rfc6454#section-7.3.
|
||||
try:
|
||||
origin = headers.get("Origin")
|
||||
except MultipleValuesError as exc:
|
||||
raise InvalidHeader("Origin", "multiple values") from exc
|
||||
if origin is not None:
|
||||
origin = cast(Origin, origin)
|
||||
if self.origins is not None:
|
||||
if origin not in self.origins:
|
||||
raise InvalidOrigin(origin)
|
||||
return origin
|
||||
|
||||
def process_extensions(
|
||||
self,
|
||||
headers: Headers,
|
||||
) -> tuple[str | None, list[Extension]]:
|
||||
"""
|
||||
Handle the Sec-WebSocket-Extensions HTTP request header.
|
||||
|
||||
Accept or reject each extension proposed in the client request.
|
||||
Negotiate parameters for accepted extensions.
|
||||
|
||||
Per :rfc:`6455`, negotiation rules are defined by the specification of
|
||||
each extension.
|
||||
|
||||
To provide this level of flexibility, for each extension proposed by
|
||||
the client, we check for a match with each extension available in the
|
||||
server configuration. If no match is found, the extension is ignored.
|
||||
|
||||
If several variants of the same extension are proposed by the client,
|
||||
it may be accepted several times, which won't make sense in general.
|
||||
Extensions must implement their own requirements. For this purpose,
|
||||
the list of previously accepted extensions is provided.
|
||||
|
||||
This process doesn't allow the server to reorder extensions. It can
|
||||
only select a subset of the extensions proposed by the client.
|
||||
|
||||
Other requirements, for example related to mandatory extensions or the
|
||||
order of extensions, may be implemented by overriding this method.
|
||||
|
||||
Args:
|
||||
headers: WebSocket handshake request headers.
|
||||
|
||||
Returns:
|
||||
``Sec-WebSocket-Extensions`` HTTP response header and list of
|
||||
accepted extensions.
|
||||
|
||||
Raises:
|
||||
InvalidHandshake: If the Sec-WebSocket-Extensions header is invalid.
|
||||
|
||||
"""
|
||||
response_header_value: str | None = None
|
||||
|
||||
extension_headers: list[ExtensionHeader] = []
|
||||
accepted_extensions: list[Extension] = []
|
||||
|
||||
header_values = headers.get_all("Sec-WebSocket-Extensions")
|
||||
|
||||
if header_values and self.available_extensions:
|
||||
parsed_header_values: list[ExtensionHeader] = sum(
|
||||
[parse_extension(header_value) for header_value in header_values], []
|
||||
)
|
||||
|
||||
for name, request_params in parsed_header_values:
|
||||
for ext_factory in self.available_extensions:
|
||||
# Skip non-matching extensions based on their name.
|
||||
if ext_factory.name != name:
|
||||
continue
|
||||
|
||||
# Skip non-matching extensions based on their params.
|
||||
try:
|
||||
response_params, extension = ext_factory.process_request_params(
|
||||
request_params, accepted_extensions
|
||||
)
|
||||
except NegotiationError:
|
||||
continue
|
||||
|
||||
# Add matching extension to the final list.
|
||||
extension_headers.append((name, response_params))
|
||||
accepted_extensions.append(extension)
|
||||
|
||||
# Break out of the loop once we have a match.
|
||||
break
|
||||
|
||||
# If we didn't break from the loop, no extension in our list
|
||||
# matched what the client sent. The extension is declined.
|
||||
|
||||
# Serialize extension header.
|
||||
if extension_headers:
|
||||
response_header_value = build_extension(extension_headers)
|
||||
|
||||
return response_header_value, accepted_extensions
|
||||
|
||||
def process_subprotocol(self, headers: Headers) -> Subprotocol | None:
|
||||
"""
|
||||
Handle the Sec-WebSocket-Protocol HTTP request header.
|
||||
|
||||
Args:
|
||||
headers: WebSocket handshake request headers.
|
||||
|
||||
Returns:
|
||||
Subprotocol, if one was selected; this is also the value of the
|
||||
``Sec-WebSocket-Protocol`` response header.
|
||||
|
||||
Raises:
|
||||
InvalidHandshake: If the Sec-WebSocket-Subprotocol header is invalid.
|
||||
|
||||
"""
|
||||
subprotocols: Sequence[Subprotocol] = sum(
|
||||
[
|
||||
parse_subprotocol(header_value)
|
||||
for header_value in headers.get_all("Sec-WebSocket-Protocol")
|
||||
],
|
||||
[],
|
||||
)
|
||||
|
||||
return self.select_subprotocol(subprotocols)
|
||||
|
||||
def select_subprotocol(
|
||||
self,
|
||||
subprotocols: Sequence[Subprotocol],
|
||||
) -> Subprotocol | None:
|
||||
"""
|
||||
Pick a subprotocol among those offered by the client.
|
||||
|
||||
If several subprotocols are supported by both the client and the server,
|
||||
pick the first one in the list declared the server.
|
||||
|
||||
If the server doesn't support any subprotocols, continue without a
|
||||
subprotocol, regardless of what the client offers.
|
||||
|
||||
If the server supports at least one subprotocol and the client doesn't
|
||||
offer any, abort the handshake with an HTTP 400 error.
|
||||
|
||||
You provide a ``select_subprotocol`` argument to :class:`ServerProtocol`
|
||||
to override this logic. For example, you could accept the connection
|
||||
even if client doesn't offer a subprotocol, rather than reject it.
|
||||
|
||||
Here's how to negotiate the ``chat`` subprotocol if the client supports
|
||||
it and continue without a subprotocol otherwise::
|
||||
|
||||
def select_subprotocol(protocol, subprotocols):
|
||||
if "chat" in subprotocols:
|
||||
return "chat"
|
||||
|
||||
Args:
|
||||
subprotocols: List of subprotocols offered by the client.
|
||||
|
||||
Returns:
|
||||
Selected subprotocol, if a common subprotocol was found.
|
||||
|
||||
:obj:`None` to continue without a subprotocol.
|
||||
|
||||
Raises:
|
||||
NegotiationError: Custom implementations may raise this exception
|
||||
to abort the handshake with an HTTP 400 error.
|
||||
|
||||
"""
|
||||
# Server doesn't offer any subprotocols.
|
||||
if not self.available_subprotocols: # None or empty list
|
||||
return None
|
||||
|
||||
# Server offers at least one subprotocol but client doesn't offer any.
|
||||
if not subprotocols:
|
||||
raise NegotiationError("missing subprotocol")
|
||||
|
||||
# Server and client both offer subprotocols. Look for a shared one.
|
||||
proposed_subprotocols = set(subprotocols)
|
||||
for subprotocol in self.available_subprotocols:
|
||||
if subprotocol in proposed_subprotocols:
|
||||
return subprotocol
|
||||
|
||||
# No common subprotocol was found.
|
||||
raise NegotiationError(
|
||||
"invalid subprotocol; expected one of "
|
||||
+ ", ".join(self.available_subprotocols)
|
||||
)
|
||||
|
||||
def reject(self, status: StatusLike, text: str) -> Response:
|
||||
"""
|
||||
Create a handshake response to reject the connection.
|
||||
|
||||
A short plain text response is the best fallback when failing to
|
||||
establish a WebSocket connection.
|
||||
|
||||
You must send the handshake response with :meth:`send_response`.
|
||||
|
||||
You may modify the response before sending it, for example by changing
|
||||
HTTP headers.
|
||||
|
||||
Args:
|
||||
status: HTTP status code.
|
||||
text: HTTP response body; it will be encoded to UTF-8.
|
||||
|
||||
Returns:
|
||||
HTTP response to send to the client.
|
||||
|
||||
"""
|
||||
# If a user passes an int instead of a HTTPStatus, fix it automatically.
|
||||
status = http.HTTPStatus(status)
|
||||
body = text.encode()
|
||||
headers = Headers(
|
||||
[
|
||||
("Date", email.utils.formatdate(usegmt=True)),
|
||||
("Connection", "close"),
|
||||
("Content-Length", str(len(body))),
|
||||
("Content-Type", "text/plain; charset=utf-8"),
|
||||
]
|
||||
)
|
||||
return Response(status.value, status.phrase, headers, body)
|
||||
|
||||
def send_response(self, response: Response) -> None:
|
||||
"""
|
||||
Send a handshake response to the client.
|
||||
|
||||
Args:
|
||||
response: WebSocket handshake response event to send.
|
||||
|
||||
"""
|
||||
if self.debug:
|
||||
code, phrase = response.status_code, response.reason_phrase
|
||||
self.logger.debug("> HTTP/1.1 %d %s", code, phrase)
|
||||
for key, value in response.headers.raw_items():
|
||||
self.logger.debug("> %s: %s", key, value)
|
||||
if response.body is not None:
|
||||
self.logger.debug("> [body] (%d bytes)", len(response.body))
|
||||
|
||||
self.writes.append(response.serialize())
|
||||
|
||||
if response.status_code == 101:
|
||||
assert self.state is CONNECTING
|
||||
self.state = OPEN
|
||||
self.logger.info("connection open")
|
||||
|
||||
else:
|
||||
# handshake_exc may be already set if accept() encountered an error.
|
||||
# If the connection isn't open, set handshake_exc to guarantee that
|
||||
# handshake_exc is None if and only if opening handshake succeeded.
|
||||
if self.handshake_exc is None:
|
||||
self.handshake_exc = InvalidStatus(response)
|
||||
self.logger.info(
|
||||
"connection rejected (%d %s)",
|
||||
response.status_code,
|
||||
response.reason_phrase,
|
||||
)
|
||||
|
||||
self.send_eof()
|
||||
self.parser = self.discard()
|
||||
next(self.parser) # start coroutine
|
||||
|
||||
def parse(self) -> Generator[None, None, None]:
|
||||
if self.state is CONNECTING:
|
||||
try:
|
||||
request = yield from Request.parse(
|
||||
self.reader.read_line,
|
||||
)
|
||||
except Exception as exc:
|
||||
self.handshake_exc = exc
|
||||
self.send_eof()
|
||||
self.parser = self.discard()
|
||||
next(self.parser) # start coroutine
|
||||
yield
|
||||
|
||||
if self.debug:
|
||||
self.logger.debug("< GET %s HTTP/1.1", request.path)
|
||||
for key, value in request.headers.raw_items():
|
||||
self.logger.debug("< %s: %s", key, value)
|
||||
|
||||
self.events.append(request)
|
||||
|
||||
yield from super().parse()
|
||||
|
||||
|
||||
class ServerConnection(ServerProtocol):
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
warnings.warn( # deprecated in 11.0 - 2023-04-02
|
||||
"ServerConnection was renamed to ServerProtocol",
|
||||
DeprecationWarning,
|
||||
)
|
||||
super().__init__(*args, **kwargs)
|
@ -1,222 +0,0 @@
|
||||
/* C implementation of performance sensitive functions. */
|
||||
|
||||
#define PY_SSIZE_T_CLEAN
|
||||
#include <Python.h>
|
||||
#include <stdint.h> /* uint8_t, uint32_t, uint64_t */
|
||||
|
||||
#if __ARM_NEON
|
||||
#include <arm_neon.h>
|
||||
#elif __SSE2__
|
||||
#include <emmintrin.h>
|
||||
#endif
|
||||
|
||||
static const Py_ssize_t MASK_LEN = 4;
|
||||
|
||||
/* Similar to PyBytes_AsStringAndSize, but accepts more types */
|
||||
|
||||
static int
|
||||
_PyBytesLike_AsStringAndSize(PyObject *obj, PyObject **tmp, char **buffer, Py_ssize_t *length)
|
||||
{
|
||||
// This supports bytes, bytearrays, and memoryview objects,
|
||||
// which are common data structures for handling byte streams.
|
||||
// If *tmp isn't NULL, the caller gets a new reference.
|
||||
if (PyBytes_Check(obj))
|
||||
{
|
||||
*tmp = NULL;
|
||||
*buffer = PyBytes_AS_STRING(obj);
|
||||
*length = PyBytes_GET_SIZE(obj);
|
||||
}
|
||||
else if (PyByteArray_Check(obj))
|
||||
{
|
||||
*tmp = NULL;
|
||||
*buffer = PyByteArray_AS_STRING(obj);
|
||||
*length = PyByteArray_GET_SIZE(obj);
|
||||
}
|
||||
else if (PyMemoryView_Check(obj))
|
||||
{
|
||||
*tmp = PyMemoryView_GetContiguous(obj, PyBUF_READ, 'C');
|
||||
if (*tmp == NULL)
|
||||
{
|
||||
return -1;
|
||||
}
|
||||
Py_buffer *mv_buf;
|
||||
mv_buf = PyMemoryView_GET_BUFFER(*tmp);
|
||||
*buffer = mv_buf->buf;
|
||||
*length = mv_buf->len;
|
||||
}
|
||||
else
|
||||
{
|
||||
PyErr_Format(
|
||||
PyExc_TypeError,
|
||||
"expected a bytes-like object, %.200s found",
|
||||
Py_TYPE(obj)->tp_name);
|
||||
return -1;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
/* C implementation of websockets.utils.apply_mask */
|
||||
|
||||
static PyObject *
|
||||
apply_mask(PyObject *self, PyObject *args, PyObject *kwds)
|
||||
{
|
||||
|
||||
// In order to support various bytes-like types, accept any Python object.
|
||||
|
||||
static char *kwlist[] = {"data", "mask", NULL};
|
||||
PyObject *input_obj;
|
||||
PyObject *mask_obj;
|
||||
|
||||
// A pointer to a char * + length will be extracted from the data and mask
|
||||
// arguments, possibly via a Py_buffer.
|
||||
|
||||
PyObject *input_tmp = NULL;
|
||||
char *input;
|
||||
Py_ssize_t input_len;
|
||||
PyObject *mask_tmp = NULL;
|
||||
char *mask;
|
||||
Py_ssize_t mask_len;
|
||||
|
||||
// Initialize a PyBytesObject then get a pointer to the underlying char *
|
||||
// in order to avoid an extra memory copy in PyBytes_FromStringAndSize.
|
||||
|
||||
PyObject *result = NULL;
|
||||
char *output;
|
||||
|
||||
// Other variables.
|
||||
|
||||
Py_ssize_t i = 0;
|
||||
|
||||
// Parse inputs.
|
||||
|
||||
if (!PyArg_ParseTupleAndKeywords(
|
||||
args, kwds, "OO", kwlist, &input_obj, &mask_obj))
|
||||
{
|
||||
goto exit;
|
||||
}
|
||||
|
||||
if (_PyBytesLike_AsStringAndSize(input_obj, &input_tmp, &input, &input_len) == -1)
|
||||
{
|
||||
goto exit;
|
||||
}
|
||||
|
||||
if (_PyBytesLike_AsStringAndSize(mask_obj, &mask_tmp, &mask, &mask_len) == -1)
|
||||
{
|
||||
goto exit;
|
||||
}
|
||||
|
||||
if (mask_len != MASK_LEN)
|
||||
{
|
||||
PyErr_SetString(PyExc_ValueError, "mask must contain 4 bytes");
|
||||
goto exit;
|
||||
}
|
||||
|
||||
// Create output.
|
||||
|
||||
result = PyBytes_FromStringAndSize(NULL, input_len);
|
||||
if (result == NULL)
|
||||
{
|
||||
goto exit;
|
||||
}
|
||||
|
||||
// Since we just created result, we don't need error checks.
|
||||
output = PyBytes_AS_STRING(result);
|
||||
|
||||
// Perform the masking operation.
|
||||
|
||||
// Apparently GCC cannot figure out the following optimizations by itself.
|
||||
|
||||
// We need a new scope for MSVC 2010 (non C99 friendly)
|
||||
{
|
||||
#if __ARM_NEON
|
||||
|
||||
// With NEON support, XOR by blocks of 16 bytes = 128 bits.
|
||||
|
||||
Py_ssize_t input_len_128 = input_len & ~15;
|
||||
uint8x16_t mask_128 = vreinterpretq_u8_u32(vdupq_n_u32(*(uint32_t *)mask));
|
||||
|
||||
for (; i < input_len_128; i += 16)
|
||||
{
|
||||
uint8x16_t in_128 = vld1q_u8((uint8_t *)(input + i));
|
||||
uint8x16_t out_128 = veorq_u8(in_128, mask_128);
|
||||
vst1q_u8((uint8_t *)(output + i), out_128);
|
||||
}
|
||||
|
||||
#elif __SSE2__
|
||||
|
||||
// With SSE2 support, XOR by blocks of 16 bytes = 128 bits.
|
||||
|
||||
// Since we cannot control the 16-bytes alignment of input and output
|
||||
// buffers, we rely on loadu/storeu rather than load/store.
|
||||
|
||||
Py_ssize_t input_len_128 = input_len & ~15;
|
||||
__m128i mask_128 = _mm_set1_epi32(*(uint32_t *)mask);
|
||||
|
||||
for (; i < input_len_128; i += 16)
|
||||
{
|
||||
__m128i in_128 = _mm_loadu_si128((__m128i *)(input + i));
|
||||
__m128i out_128 = _mm_xor_si128(in_128, mask_128);
|
||||
_mm_storeu_si128((__m128i *)(output + i), out_128);
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
// Without SSE2 support, XOR by blocks of 8 bytes = 64 bits.
|
||||
|
||||
// We assume the memory allocator aligns everything on 8 bytes boundaries.
|
||||
|
||||
Py_ssize_t input_len_64 = input_len & ~7;
|
||||
uint32_t mask_32 = *(uint32_t *)mask;
|
||||
uint64_t mask_64 = ((uint64_t)mask_32 << 32) | (uint64_t)mask_32;
|
||||
|
||||
for (; i < input_len_64; i += 8)
|
||||
{
|
||||
*(uint64_t *)(output + i) = *(uint64_t *)(input + i) ^ mask_64;
|
||||
}
|
||||
|
||||
#endif
|
||||
}
|
||||
|
||||
// XOR the remainder of the input byte by byte.
|
||||
|
||||
for (; i < input_len; i++)
|
||||
{
|
||||
output[i] = input[i] ^ mask[i & (MASK_LEN - 1)];
|
||||
}
|
||||
|
||||
exit:
|
||||
Py_XDECREF(input_tmp);
|
||||
Py_XDECREF(mask_tmp);
|
||||
return result;
|
||||
|
||||
}
|
||||
|
||||
static PyMethodDef speedups_methods[] = {
|
||||
{
|
||||
"apply_mask",
|
||||
(PyCFunction)apply_mask,
|
||||
METH_VARARGS | METH_KEYWORDS,
|
||||
"Apply masking to the data of a WebSocket message.",
|
||||
},
|
||||
{NULL, NULL, 0, NULL}, /* Sentinel */
|
||||
};
|
||||
|
||||
static struct PyModuleDef speedups_module = {
|
||||
PyModuleDef_HEAD_INIT,
|
||||
"websocket.speedups", /* m_name */
|
||||
"C implementation of performance sensitive functions.",
|
||||
/* m_doc */
|
||||
-1, /* m_size */
|
||||
speedups_methods, /* m_methods */
|
||||
NULL,
|
||||
NULL,
|
||||
NULL,
|
||||
NULL
|
||||
};
|
||||
|
||||
PyMODINIT_FUNC
|
||||
PyInit_speedups(void)
|
||||
{
|
||||
return PyModule_Create(&speedups_module);
|
||||
}
|
@ -1 +0,0 @@
|
||||
def apply_mask(data: bytes, mask: bytes) -> bytes: ...
|
@ -1,151 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Generator
|
||||
|
||||
|
||||
class StreamReader:
|
||||
"""
|
||||
Generator-based stream reader.
|
||||
|
||||
This class doesn't support concurrent calls to :meth:`read_line`,
|
||||
:meth:`read_exact`, or :meth:`read_to_eof`. Make sure calls are
|
||||
serialized.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.buffer = bytearray()
|
||||
self.eof = False
|
||||
|
||||
def read_line(self, m: int) -> Generator[None, None, bytes]:
|
||||
"""
|
||||
Read a LF-terminated line from the stream.
|
||||
|
||||
This is a generator-based coroutine.
|
||||
|
||||
The return value includes the LF character.
|
||||
|
||||
Args:
|
||||
m: Maximum number bytes to read; this is a security limit.
|
||||
|
||||
Raises:
|
||||
EOFError: If the stream ends without a LF.
|
||||
RuntimeError: If the stream ends in more than ``m`` bytes.
|
||||
|
||||
"""
|
||||
n = 0 # number of bytes to read
|
||||
p = 0 # number of bytes without a newline
|
||||
while True:
|
||||
n = self.buffer.find(b"\n", p) + 1
|
||||
if n > 0:
|
||||
break
|
||||
p = len(self.buffer)
|
||||
if p > m:
|
||||
raise RuntimeError(f"read {p} bytes, expected no more than {m} bytes")
|
||||
if self.eof:
|
||||
raise EOFError(f"stream ends after {p} bytes, before end of line")
|
||||
yield
|
||||
if n > m:
|
||||
raise RuntimeError(f"read {n} bytes, expected no more than {m} bytes")
|
||||
r = self.buffer[:n]
|
||||
del self.buffer[:n]
|
||||
return r
|
||||
|
||||
def read_exact(self, n: int) -> Generator[None, None, bytes]:
|
||||
"""
|
||||
Read a given number of bytes from the stream.
|
||||
|
||||
This is a generator-based coroutine.
|
||||
|
||||
Args:
|
||||
n: How many bytes to read.
|
||||
|
||||
Raises:
|
||||
EOFError: If the stream ends in less than ``n`` bytes.
|
||||
|
||||
"""
|
||||
assert n >= 0
|
||||
while len(self.buffer) < n:
|
||||
if self.eof:
|
||||
p = len(self.buffer)
|
||||
raise EOFError(f"stream ends after {p} bytes, expected {n} bytes")
|
||||
yield
|
||||
r = self.buffer[:n]
|
||||
del self.buffer[:n]
|
||||
return r
|
||||
|
||||
def read_to_eof(self, m: int) -> Generator[None, None, bytes]:
|
||||
"""
|
||||
Read all bytes from the stream.
|
||||
|
||||
This is a generator-based coroutine.
|
||||
|
||||
Args:
|
||||
m: Maximum number bytes to read; this is a security limit.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the stream ends in more than ``m`` bytes.
|
||||
|
||||
"""
|
||||
while not self.eof:
|
||||
p = len(self.buffer)
|
||||
if p > m:
|
||||
raise RuntimeError(f"read {p} bytes, expected no more than {m} bytes")
|
||||
yield
|
||||
r = self.buffer[:]
|
||||
del self.buffer[:]
|
||||
return r
|
||||
|
||||
def at_eof(self) -> Generator[None, None, bool]:
|
||||
"""
|
||||
Tell whether the stream has ended and all data was read.
|
||||
|
||||
This is a generator-based coroutine.
|
||||
|
||||
"""
|
||||
while True:
|
||||
if self.buffer:
|
||||
return False
|
||||
if self.eof:
|
||||
return True
|
||||
# When all data was read but the stream hasn't ended, we can't
|
||||
# tell if until either feed_data() or feed_eof() is called.
|
||||
yield
|
||||
|
||||
def feed_data(self, data: bytes) -> None:
|
||||
"""
|
||||
Write data to the stream.
|
||||
|
||||
:meth:`feed_data` cannot be called after :meth:`feed_eof`.
|
||||
|
||||
Args:
|
||||
data: Data to write.
|
||||
|
||||
Raises:
|
||||
EOFError: If the stream has ended.
|
||||
|
||||
"""
|
||||
if self.eof:
|
||||
raise EOFError("stream ended")
|
||||
self.buffer += data
|
||||
|
||||
def feed_eof(self) -> None:
|
||||
"""
|
||||
End the stream.
|
||||
|
||||
:meth:`feed_eof` cannot be called more than once.
|
||||
|
||||
Raises:
|
||||
EOFError: If the stream has ended.
|
||||
|
||||
"""
|
||||
if self.eof:
|
||||
raise EOFError("stream ended")
|
||||
self.eof = True
|
||||
|
||||
def discard(self) -> None:
|
||||
"""
|
||||
Discard all buffered data, but don't end the stream.
|
||||
|
||||
"""
|
||||
del self.buffer[:]
|
@ -1,336 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import socket
|
||||
import ssl as ssl_module
|
||||
import threading
|
||||
import warnings
|
||||
from typing import Any, Sequence
|
||||
|
||||
from ..client import ClientProtocol
|
||||
from ..datastructures import HeadersLike
|
||||
from ..extensions.base import ClientExtensionFactory
|
||||
from ..extensions.permessage_deflate import enable_client_permessage_deflate
|
||||
from ..headers import validate_subprotocols
|
||||
from ..http11 import USER_AGENT, Response
|
||||
from ..protocol import CONNECTING, Event
|
||||
from ..typing import LoggerLike, Origin, Subprotocol
|
||||
from ..uri import parse_uri
|
||||
from .connection import Connection
|
||||
from .utils import Deadline
|
||||
|
||||
|
||||
__all__ = ["connect", "unix_connect", "ClientConnection"]
|
||||
|
||||
|
||||
class ClientConnection(Connection):
|
||||
"""
|
||||
:mod:`threading` implementation of a WebSocket client connection.
|
||||
|
||||
:class:`ClientConnection` provides :meth:`recv` and :meth:`send` methods for
|
||||
receiving and sending messages.
|
||||
|
||||
It supports iteration to receive messages::
|
||||
|
||||
for message in websocket:
|
||||
process(message)
|
||||
|
||||
The iterator exits normally when the connection is closed with close code
|
||||
1000 (OK) or 1001 (going away) or without a close code. It raises a
|
||||
:exc:`~websockets.exceptions.ConnectionClosedError` when the connection is
|
||||
closed with any other code.
|
||||
|
||||
Args:
|
||||
socket: Socket connected to a WebSocket server.
|
||||
protocol: Sans-I/O connection.
|
||||
close_timeout: Timeout for closing the connection in seconds.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
socket: socket.socket,
|
||||
protocol: ClientProtocol,
|
||||
*,
|
||||
close_timeout: float | None = 10,
|
||||
) -> None:
|
||||
self.protocol: ClientProtocol
|
||||
self.response_rcvd = threading.Event()
|
||||
super().__init__(
|
||||
socket,
|
||||
protocol,
|
||||
close_timeout=close_timeout,
|
||||
)
|
||||
|
||||
def handshake(
|
||||
self,
|
||||
additional_headers: HeadersLike | None = None,
|
||||
user_agent_header: str | None = USER_AGENT,
|
||||
timeout: float | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Perform the opening handshake.
|
||||
|
||||
"""
|
||||
with self.send_context(expected_state=CONNECTING):
|
||||
self.request = self.protocol.connect()
|
||||
if additional_headers is not None:
|
||||
self.request.headers.update(additional_headers)
|
||||
if user_agent_header is not None:
|
||||
self.request.headers["User-Agent"] = user_agent_header
|
||||
self.protocol.send_request(self.request)
|
||||
|
||||
if not self.response_rcvd.wait(timeout):
|
||||
raise TimeoutError("timed out during handshake")
|
||||
|
||||
# self.protocol.handshake_exc is always set when the connection is lost
|
||||
# before receiving a response, when the response cannot be parsed, or
|
||||
# when the response fails the handshake.
|
||||
|
||||
if self.protocol.handshake_exc is not None:
|
||||
raise self.protocol.handshake_exc
|
||||
|
||||
def process_event(self, event: Event) -> None:
|
||||
"""
|
||||
Process one incoming event.
|
||||
|
||||
"""
|
||||
# First event - handshake response.
|
||||
if self.response is None:
|
||||
assert isinstance(event, Response)
|
||||
self.response = event
|
||||
self.response_rcvd.set()
|
||||
# Later events - frames.
|
||||
else:
|
||||
super().process_event(event)
|
||||
|
||||
def recv_events(self) -> None:
|
||||
"""
|
||||
Read incoming data from the socket and process events.
|
||||
|
||||
"""
|
||||
try:
|
||||
super().recv_events()
|
||||
finally:
|
||||
# If the connection is closed during the handshake, unblock it.
|
||||
self.response_rcvd.set()
|
||||
|
||||
|
||||
def connect(
|
||||
uri: str,
|
||||
*,
|
||||
# TCP/TLS
|
||||
sock: socket.socket | None = None,
|
||||
ssl: ssl_module.SSLContext | None = None,
|
||||
server_hostname: str | None = None,
|
||||
# WebSocket
|
||||
origin: Origin | None = None,
|
||||
extensions: Sequence[ClientExtensionFactory] | None = None,
|
||||
subprotocols: Sequence[Subprotocol] | None = None,
|
||||
additional_headers: HeadersLike | None = None,
|
||||
user_agent_header: str | None = USER_AGENT,
|
||||
compression: str | None = "deflate",
|
||||
# Timeouts
|
||||
open_timeout: float | None = 10,
|
||||
close_timeout: float | None = 10,
|
||||
# Limits
|
||||
max_size: int | None = 2**20,
|
||||
# Logging
|
||||
logger: LoggerLike | None = None,
|
||||
# Escape hatch for advanced customization
|
||||
create_connection: type[ClientConnection] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ClientConnection:
|
||||
"""
|
||||
Connect to the WebSocket server at ``uri``.
|
||||
|
||||
This function returns a :class:`ClientConnection` instance, which you can
|
||||
use to send and receive messages.
|
||||
|
||||
:func:`connect` may be used as a context manager::
|
||||
|
||||
from websockets.sync.client import connect
|
||||
|
||||
with connect(...) as websocket:
|
||||
...
|
||||
|
||||
The connection is closed automatically when exiting the context.
|
||||
|
||||
Args:
|
||||
uri: URI of the WebSocket server.
|
||||
sock: Preexisting TCP socket. ``sock`` overrides the host and port
|
||||
from ``uri``. You may call :func:`socket.create_connection` to
|
||||
create a suitable TCP socket.
|
||||
ssl: Configuration for enabling TLS on the connection.
|
||||
server_hostname: Host name for the TLS handshake. ``server_hostname``
|
||||
overrides the host name from ``uri``.
|
||||
origin: Value of the ``Origin`` header, for servers that require it.
|
||||
extensions: List of supported extensions, in order in which they
|
||||
should be negotiated and run.
|
||||
subprotocols: List of supported subprotocols, in order of decreasing
|
||||
preference.
|
||||
additional_headers (HeadersLike | None): Arbitrary HTTP headers to add
|
||||
to the handshake request.
|
||||
user_agent_header: Value of the ``User-Agent`` request header.
|
||||
It defaults to ``"Python/x.y.z websockets/X.Y"``.
|
||||
Setting it to :obj:`None` removes the header.
|
||||
compression: The "permessage-deflate" extension is enabled by default.
|
||||
Set ``compression`` to :obj:`None` to disable it. See the
|
||||
:doc:`compression guide <../../topics/compression>` for details.
|
||||
open_timeout: Timeout for opening the connection in seconds.
|
||||
:obj:`None` disables the timeout.
|
||||
close_timeout: Timeout for closing the connection in seconds.
|
||||
:obj:`None` disables the timeout.
|
||||
max_size: Maximum size of incoming messages in bytes.
|
||||
:obj:`None` disables the limit.
|
||||
logger: Logger for this client.
|
||||
It defaults to ``logging.getLogger("websockets.client")``.
|
||||
See the :doc:`logging guide <../../topics/logging>` for details.
|
||||
create_connection: Factory for the :class:`ClientConnection` managing
|
||||
the connection. Set it to a wrapper or a subclass to customize
|
||||
connection handling.
|
||||
|
||||
Any other keyword arguments are passed to :func:`~socket.create_connection`.
|
||||
|
||||
Raises:
|
||||
InvalidURI: If ``uri`` isn't a valid WebSocket URI.
|
||||
OSError: If the TCP connection fails.
|
||||
InvalidHandshake: If the opening handshake fails.
|
||||
TimeoutError: If the opening handshake times out.
|
||||
|
||||
"""
|
||||
|
||||
# Process parameters
|
||||
|
||||
# Backwards compatibility: ssl used to be called ssl_context.
|
||||
if ssl is None and "ssl_context" in kwargs:
|
||||
ssl = kwargs.pop("ssl_context")
|
||||
warnings.warn( # deprecated in 13.0 - 2024-08-20
|
||||
"ssl_context was renamed to ssl",
|
||||
DeprecationWarning,
|
||||
)
|
||||
|
||||
wsuri = parse_uri(uri)
|
||||
if not wsuri.secure and ssl is not None:
|
||||
raise TypeError("ssl argument is incompatible with a ws:// URI")
|
||||
|
||||
# Private APIs for unix_connect()
|
||||
unix: bool = kwargs.pop("unix", False)
|
||||
path: str | None = kwargs.pop("path", None)
|
||||
|
||||
if unix:
|
||||
if path is None and sock is None:
|
||||
raise TypeError("missing path argument")
|
||||
elif path is not None and sock is not None:
|
||||
raise TypeError("path and sock arguments are incompatible")
|
||||
|
||||
if subprotocols is not None:
|
||||
validate_subprotocols(subprotocols)
|
||||
|
||||
if compression == "deflate":
|
||||
extensions = enable_client_permessage_deflate(extensions)
|
||||
elif compression is not None:
|
||||
raise ValueError(f"unsupported compression: {compression}")
|
||||
|
||||
# Calculate timeouts on the TCP, TLS, and WebSocket handshakes.
|
||||
# The TCP and TLS timeouts must be set on the socket, then removed
|
||||
# to avoid conflicting with the WebSocket timeout in handshake().
|
||||
deadline = Deadline(open_timeout)
|
||||
|
||||
if create_connection is None:
|
||||
create_connection = ClientConnection
|
||||
|
||||
try:
|
||||
# Connect socket
|
||||
|
||||
if sock is None:
|
||||
if unix:
|
||||
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
|
||||
sock.settimeout(deadline.timeout())
|
||||
assert path is not None # mypy cannot figure this out
|
||||
sock.connect(path)
|
||||
else:
|
||||
kwargs.setdefault("timeout", deadline.timeout())
|
||||
sock = socket.create_connection((wsuri.host, wsuri.port), **kwargs)
|
||||
sock.settimeout(None)
|
||||
|
||||
# Disable Nagle algorithm
|
||||
|
||||
if not unix:
|
||||
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, True)
|
||||
|
||||
# Initialize TLS wrapper and perform TLS handshake
|
||||
|
||||
if wsuri.secure:
|
||||
if ssl is None:
|
||||
ssl = ssl_module.create_default_context()
|
||||
if server_hostname is None:
|
||||
server_hostname = wsuri.host
|
||||
sock.settimeout(deadline.timeout())
|
||||
sock = ssl.wrap_socket(sock, server_hostname=server_hostname)
|
||||
sock.settimeout(None)
|
||||
|
||||
# Initialize WebSocket protocol
|
||||
|
||||
protocol = ClientProtocol(
|
||||
wsuri,
|
||||
origin=origin,
|
||||
extensions=extensions,
|
||||
subprotocols=subprotocols,
|
||||
max_size=max_size,
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
# Initialize WebSocket connection
|
||||
|
||||
connection = create_connection(
|
||||
sock,
|
||||
protocol,
|
||||
close_timeout=close_timeout,
|
||||
)
|
||||
except Exception:
|
||||
if sock is not None:
|
||||
sock.close()
|
||||
raise
|
||||
|
||||
try:
|
||||
connection.handshake(
|
||||
additional_headers,
|
||||
user_agent_header,
|
||||
deadline.timeout(),
|
||||
)
|
||||
except Exception:
|
||||
connection.close_socket()
|
||||
connection.recv_events_thread.join()
|
||||
raise
|
||||
|
||||
return connection
|
||||
|
||||
|
||||
def unix_connect(
|
||||
path: str | None = None,
|
||||
uri: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ClientConnection:
|
||||
"""
|
||||
Connect to a WebSocket server listening on a Unix socket.
|
||||
|
||||
This function accepts the same keyword arguments as :func:`connect`.
|
||||
|
||||
It's only available on Unix.
|
||||
|
||||
It's mainly useful for debugging servers listening on Unix sockets.
|
||||
|
||||
Args:
|
||||
path: File system path to the Unix socket.
|
||||
uri: URI of the WebSocket server. ``uri`` defaults to
|
||||
``ws://localhost/`` or, when a ``ssl`` is provided, to
|
||||
``wss://localhost/``.
|
||||
|
||||
"""
|
||||
if uri is None:
|
||||
# Backwards compatibility: ssl used to be called ssl_context.
|
||||
if kwargs.get("ssl") is None and kwargs.get("ssl_context") is None:
|
||||
uri = "ws://localhost/"
|
||||
else:
|
||||
uri = "wss://localhost/"
|
||||
return connect(uri=uri, unix=True, path=path, **kwargs)
|
@ -1,791 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import logging
|
||||
import random
|
||||
import socket
|
||||
import struct
|
||||
import threading
|
||||
import uuid
|
||||
from types import TracebackType
|
||||
from typing import Any, Iterable, Iterator, Mapping
|
||||
|
||||
from ..exceptions import (
|
||||
ConcurrencyError,
|
||||
ConnectionClosed,
|
||||
ConnectionClosedOK,
|
||||
ProtocolError,
|
||||
)
|
||||
from ..frames import DATA_OPCODES, BytesLike, CloseCode, Frame, Opcode
|
||||
from ..http11 import Request, Response
|
||||
from ..protocol import CLOSED, OPEN, Event, Protocol, State
|
||||
from ..typing import Data, LoggerLike, Subprotocol
|
||||
from .messages import Assembler
|
||||
from .utils import Deadline
|
||||
|
||||
|
||||
__all__ = ["Connection"]
|
||||
|
||||
|
||||
class Connection:
|
||||
"""
|
||||
:mod:`threading` implementation of a WebSocket connection.
|
||||
|
||||
:class:`Connection` provides APIs shared between WebSocket servers and
|
||||
clients.
|
||||
|
||||
You shouldn't use it directly. Instead, use
|
||||
:class:`~websockets.sync.client.ClientConnection` or
|
||||
:class:`~websockets.sync.server.ServerConnection`.
|
||||
|
||||
"""
|
||||
|
||||
recv_bufsize = 65536
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
socket: socket.socket,
|
||||
protocol: Protocol,
|
||||
*,
|
||||
close_timeout: float | None = 10,
|
||||
) -> None:
|
||||
self.socket = socket
|
||||
self.protocol = protocol
|
||||
self.close_timeout = close_timeout
|
||||
|
||||
# Inject reference to this instance in the protocol's logger.
|
||||
self.protocol.logger = logging.LoggerAdapter(
|
||||
self.protocol.logger,
|
||||
{"websocket": self},
|
||||
)
|
||||
|
||||
# Copy attributes from the protocol for convenience.
|
||||
self.id: uuid.UUID = self.protocol.id
|
||||
"""Unique identifier of the connection. Useful in logs."""
|
||||
self.logger: LoggerLike = self.protocol.logger
|
||||
"""Logger for this connection."""
|
||||
self.debug = self.protocol.debug
|
||||
|
||||
# HTTP handshake request and response.
|
||||
self.request: Request | None = None
|
||||
"""Opening handshake request."""
|
||||
self.response: Response | None = None
|
||||
"""Opening handshake response."""
|
||||
|
||||
# Mutex serializing interactions with the protocol.
|
||||
self.protocol_mutex = threading.Lock()
|
||||
|
||||
# Assembler turning frames into messages and serializing reads.
|
||||
self.recv_messages = Assembler()
|
||||
|
||||
# Whether we are busy sending a fragmented message.
|
||||
self.send_in_progress = False
|
||||
|
||||
# Deadline for the closing handshake.
|
||||
self.close_deadline: Deadline | None = None
|
||||
|
||||
# Mapping of ping IDs to pong waiters, in chronological order.
|
||||
self.ping_waiters: dict[bytes, threading.Event] = {}
|
||||
|
||||
# Receiving events from the socket. This thread explicitly is marked as
|
||||
# to support creating a connection in a non-daemon thread then using it
|
||||
# in a daemon thread; this shouldn't block the intpreter from exiting.
|
||||
self.recv_events_thread = threading.Thread(
|
||||
target=self.recv_events,
|
||||
daemon=True,
|
||||
)
|
||||
self.recv_events_thread.start()
|
||||
|
||||
# Exception raised in recv_events, to be chained to ConnectionClosed
|
||||
# in the user thread in order to show why the TCP connection dropped.
|
||||
self.recv_exc: BaseException | None = None
|
||||
|
||||
# Public attributes
|
||||
|
||||
@property
|
||||
def local_address(self) -> Any:
|
||||
"""
|
||||
Local address of the connection.
|
||||
|
||||
For IPv4 connections, this is a ``(host, port)`` tuple.
|
||||
|
||||
The format of the address depends on the address family.
|
||||
See :meth:`~socket.socket.getsockname`.
|
||||
|
||||
"""
|
||||
return self.socket.getsockname()
|
||||
|
||||
@property
|
||||
def remote_address(self) -> Any:
|
||||
"""
|
||||
Remote address of the connection.
|
||||
|
||||
For IPv4 connections, this is a ``(host, port)`` tuple.
|
||||
|
||||
The format of the address depends on the address family.
|
||||
See :meth:`~socket.socket.getpeername`.
|
||||
|
||||
"""
|
||||
return self.socket.getpeername()
|
||||
|
||||
@property
|
||||
def subprotocol(self) -> Subprotocol | None:
|
||||
"""
|
||||
Subprotocol negotiated during the opening handshake.
|
||||
|
||||
:obj:`None` if no subprotocol was negotiated.
|
||||
|
||||
"""
|
||||
return self.protocol.subprotocol
|
||||
|
||||
# Public methods
|
||||
|
||||
def __enter__(self) -> Connection:
|
||||
return self
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc_value: BaseException | None,
|
||||
traceback: TracebackType | None,
|
||||
) -> None:
|
||||
if exc_type is None:
|
||||
self.close()
|
||||
else:
|
||||
self.close(CloseCode.INTERNAL_ERROR)
|
||||
|
||||
def __iter__(self) -> Iterator[Data]:
|
||||
"""
|
||||
Iterate on incoming messages.
|
||||
|
||||
The iterator calls :meth:`recv` and yields messages in an infinite loop.
|
||||
|
||||
It exits when the connection is closed normally. It raises a
|
||||
:exc:`~websockets.exceptions.ConnectionClosedError` exception after a
|
||||
protocol error or a network failure.
|
||||
|
||||
"""
|
||||
try:
|
||||
while True:
|
||||
yield self.recv()
|
||||
except ConnectionClosedOK:
|
||||
return
|
||||
|
||||
def recv(self, timeout: float | None = None) -> Data:
|
||||
"""
|
||||
Receive the next message.
|
||||
|
||||
When the connection is closed, :meth:`recv` raises
|
||||
:exc:`~websockets.exceptions.ConnectionClosed`. Specifically, it raises
|
||||
:exc:`~websockets.exceptions.ConnectionClosedOK` after a normal closure
|
||||
and :exc:`~websockets.exceptions.ConnectionClosedError` after a protocol
|
||||
error or a network failure. This is how you detect the end of the
|
||||
message stream.
|
||||
|
||||
If ``timeout`` is :obj:`None`, block until a message is received. If
|
||||
``timeout`` is set and no message is received within ``timeout``
|
||||
seconds, raise :exc:`TimeoutError`. Set ``timeout`` to ``0`` to check if
|
||||
a message was already received.
|
||||
|
||||
If the message is fragmented, wait until all fragments are received,
|
||||
reassemble them, and return the whole message.
|
||||
|
||||
Returns:
|
||||
A string (:class:`str`) for a Text_ frame or a bytestring
|
||||
(:class:`bytes`) for a Binary_ frame.
|
||||
|
||||
.. _Text: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6
|
||||
.. _Binary: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6
|
||||
|
||||
Raises:
|
||||
ConnectionClosed: When the connection is closed.
|
||||
ConcurrencyError: If two threads call :meth:`recv` or
|
||||
:meth:`recv_streaming` concurrently.
|
||||
|
||||
"""
|
||||
try:
|
||||
return self.recv_messages.get(timeout)
|
||||
except EOFError:
|
||||
raise self.protocol.close_exc from self.recv_exc
|
||||
except ConcurrencyError:
|
||||
raise ConcurrencyError(
|
||||
"cannot call recv while another thread "
|
||||
"is already running recv or recv_streaming"
|
||||
) from None
|
||||
|
||||
def recv_streaming(self) -> Iterator[Data]:
|
||||
"""
|
||||
Receive the next message frame by frame.
|
||||
|
||||
If the message is fragmented, yield each fragment as it is received.
|
||||
The iterator must be fully consumed, or else the connection will become
|
||||
unusable.
|
||||
|
||||
:meth:`recv_streaming` raises the same exceptions as :meth:`recv`.
|
||||
|
||||
Returns:
|
||||
An iterator of strings (:class:`str`) for a Text_ frame or
|
||||
bytestrings (:class:`bytes`) for a Binary_ frame.
|
||||
|
||||
.. _Text: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6
|
||||
.. _Binary: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6
|
||||
|
||||
Raises:
|
||||
ConnectionClosed: When the connection is closed.
|
||||
ConcurrencyError: If two threads call :meth:`recv` or
|
||||
:meth:`recv_streaming` concurrently.
|
||||
|
||||
"""
|
||||
try:
|
||||
for frame in self.recv_messages.get_iter():
|
||||
yield frame
|
||||
except EOFError:
|
||||
raise self.protocol.close_exc from self.recv_exc
|
||||
except ConcurrencyError:
|
||||
raise ConcurrencyError(
|
||||
"cannot call recv_streaming while another thread "
|
||||
"is already running recv or recv_streaming"
|
||||
) from None
|
||||
|
||||
def send(self, message: Data | Iterable[Data]) -> None:
|
||||
"""
|
||||
Send a message.
|
||||
|
||||
A string (:class:`str`) is sent as a Text_ frame. A bytestring or
|
||||
bytes-like object (:class:`bytes`, :class:`bytearray`, or
|
||||
:class:`memoryview`) is sent as a Binary_ frame.
|
||||
|
||||
.. _Text: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6
|
||||
.. _Binary: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6
|
||||
|
||||
:meth:`send` also accepts an iterable of strings, bytestrings, or
|
||||
bytes-like objects to enable fragmentation_. Each item is treated as a
|
||||
message fragment and sent in its own frame. All items must be of the
|
||||
same type, or else :meth:`send` will raise a :exc:`TypeError` and the
|
||||
connection will be closed.
|
||||
|
||||
.. _fragmentation: https://datatracker.ietf.org/doc/html/rfc6455#section-5.4
|
||||
|
||||
:meth:`send` rejects dict-like objects because this is often an error.
|
||||
(If you really want to send the keys of a dict-like object as fragments,
|
||||
call its :meth:`~dict.keys` method and pass the result to :meth:`send`.)
|
||||
|
||||
When the connection is closed, :meth:`send` raises
|
||||
:exc:`~websockets.exceptions.ConnectionClosed`. Specifically, it
|
||||
raises :exc:`~websockets.exceptions.ConnectionClosedOK` after a normal
|
||||
connection closure and
|
||||
:exc:`~websockets.exceptions.ConnectionClosedError` after a protocol
|
||||
error or a network failure.
|
||||
|
||||
Args:
|
||||
message: Message to send.
|
||||
|
||||
Raises:
|
||||
ConnectionClosed: When the connection is closed.
|
||||
ConcurrencyError: If the connection is sending a fragmented message.
|
||||
TypeError: If ``message`` doesn't have a supported type.
|
||||
|
||||
"""
|
||||
# Unfragmented message -- this case must be handled first because
|
||||
# strings and bytes-like objects are iterable.
|
||||
|
||||
if isinstance(message, str):
|
||||
with self.send_context():
|
||||
if self.send_in_progress:
|
||||
raise ConcurrencyError(
|
||||
"cannot call send while another thread "
|
||||
"is already running send"
|
||||
)
|
||||
self.protocol.send_text(message.encode())
|
||||
|
||||
elif isinstance(message, BytesLike):
|
||||
with self.send_context():
|
||||
if self.send_in_progress:
|
||||
raise ConcurrencyError(
|
||||
"cannot call send while another thread "
|
||||
"is already running send"
|
||||
)
|
||||
self.protocol.send_binary(message)
|
||||
|
||||
# Catch a common mistake -- passing a dict to send().
|
||||
|
||||
elif isinstance(message, Mapping):
|
||||
raise TypeError("data is a dict-like object")
|
||||
|
||||
# Fragmented message -- regular iterator.
|
||||
|
||||
elif isinstance(message, Iterable):
|
||||
chunks = iter(message)
|
||||
try:
|
||||
chunk = next(chunks)
|
||||
except StopIteration:
|
||||
return
|
||||
|
||||
try:
|
||||
# First fragment.
|
||||
if isinstance(chunk, str):
|
||||
text = True
|
||||
with self.send_context():
|
||||
if self.send_in_progress:
|
||||
raise ConcurrencyError(
|
||||
"cannot call send while another thread "
|
||||
"is already running send"
|
||||
)
|
||||
self.send_in_progress = True
|
||||
self.protocol.send_text(
|
||||
chunk.encode(),
|
||||
fin=False,
|
||||
)
|
||||
elif isinstance(chunk, BytesLike):
|
||||
text = False
|
||||
with self.send_context():
|
||||
if self.send_in_progress:
|
||||
raise ConcurrencyError(
|
||||
"cannot call send while another thread "
|
||||
"is already running send"
|
||||
)
|
||||
self.send_in_progress = True
|
||||
self.protocol.send_binary(
|
||||
chunk,
|
||||
fin=False,
|
||||
)
|
||||
else:
|
||||
raise TypeError("data iterable must contain bytes or str")
|
||||
|
||||
# Other fragments
|
||||
for chunk in chunks:
|
||||
if isinstance(chunk, str) and text:
|
||||
with self.send_context():
|
||||
assert self.send_in_progress
|
||||
self.protocol.send_continuation(
|
||||
chunk.encode(),
|
||||
fin=False,
|
||||
)
|
||||
elif isinstance(chunk, BytesLike) and not text:
|
||||
with self.send_context():
|
||||
assert self.send_in_progress
|
||||
self.protocol.send_continuation(
|
||||
chunk,
|
||||
fin=False,
|
||||
)
|
||||
else:
|
||||
raise TypeError("data iterable must contain uniform types")
|
||||
|
||||
# Final fragment.
|
||||
with self.send_context():
|
||||
self.protocol.send_continuation(b"", fin=True)
|
||||
self.send_in_progress = False
|
||||
|
||||
except ConcurrencyError:
|
||||
# We didn't start sending a fragmented message.
|
||||
# The connection is still usable.
|
||||
raise
|
||||
|
||||
except Exception:
|
||||
# We're half-way through a fragmented message and we can't
|
||||
# complete it. This makes the connection unusable.
|
||||
with self.send_context():
|
||||
self.protocol.fail(
|
||||
CloseCode.INTERNAL_ERROR,
|
||||
"error in fragmented message",
|
||||
)
|
||||
raise
|
||||
|
||||
else:
|
||||
raise TypeError("data must be str, bytes, or iterable")
|
||||
|
||||
def close(self, code: int = CloseCode.NORMAL_CLOSURE, reason: str = "") -> None:
|
||||
"""
|
||||
Perform the closing handshake.
|
||||
|
||||
:meth:`close` waits for the other end to complete the handshake, for the
|
||||
TCP connection to terminate, and for all incoming messages to be read
|
||||
with :meth:`recv`.
|
||||
|
||||
:meth:`close` is idempotent: it doesn't do anything once the
|
||||
connection is closed.
|
||||
|
||||
Args:
|
||||
code: WebSocket close code.
|
||||
reason: WebSocket close reason.
|
||||
|
||||
"""
|
||||
try:
|
||||
# The context manager takes care of waiting for the TCP connection
|
||||
# to terminate after calling a method that sends a close frame.
|
||||
with self.send_context():
|
||||
if self.send_in_progress:
|
||||
self.protocol.fail(
|
||||
CloseCode.INTERNAL_ERROR,
|
||||
"close during fragmented message",
|
||||
)
|
||||
else:
|
||||
self.protocol.send_close(code, reason)
|
||||
except ConnectionClosed:
|
||||
# Ignore ConnectionClosed exceptions raised from send_context().
|
||||
# They mean that the connection is closed, which was the goal.
|
||||
pass
|
||||
|
||||
def ping(self, data: Data | None = None) -> threading.Event:
|
||||
"""
|
||||
Send a Ping_.
|
||||
|
||||
.. _Ping: https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.2
|
||||
|
||||
A ping may serve as a keepalive or as a check that the remote endpoint
|
||||
received all messages up to this point
|
||||
|
||||
Args:
|
||||
data: Payload of the ping. A :class:`str` will be encoded to UTF-8.
|
||||
If ``data`` is :obj:`None`, the payload is four random bytes.
|
||||
|
||||
Returns:
|
||||
An event that will be set when the corresponding pong is received.
|
||||
You can ignore it if you don't intend to wait.
|
||||
|
||||
::
|
||||
|
||||
pong_event = ws.ping()
|
||||
pong_event.wait() # only if you want to wait for the pong
|
||||
|
||||
Raises:
|
||||
ConnectionClosed: When the connection is closed.
|
||||
ConcurrencyError: If another ping was sent with the same data and
|
||||
the corresponding pong wasn't received yet.
|
||||
|
||||
"""
|
||||
if isinstance(data, BytesLike):
|
||||
data = bytes(data)
|
||||
elif isinstance(data, str):
|
||||
data = data.encode()
|
||||
elif data is not None:
|
||||
raise TypeError("data must be str or bytes-like")
|
||||
|
||||
with self.send_context():
|
||||
# Protect against duplicates if a payload is explicitly set.
|
||||
if data in self.ping_waiters:
|
||||
raise ConcurrencyError("already waiting for a pong with the same data")
|
||||
|
||||
# Generate a unique random payload otherwise.
|
||||
while data is None or data in self.ping_waiters:
|
||||
data = struct.pack("!I", random.getrandbits(32))
|
||||
|
||||
pong_waiter = threading.Event()
|
||||
self.ping_waiters[data] = pong_waiter
|
||||
self.protocol.send_ping(data)
|
||||
return pong_waiter
|
||||
|
||||
def pong(self, data: Data = b"") -> None:
|
||||
"""
|
||||
Send a Pong_.
|
||||
|
||||
.. _Pong: https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.3
|
||||
|
||||
An unsolicited pong may serve as a unidirectional heartbeat.
|
||||
|
||||
Args:
|
||||
data: Payload of the pong. A :class:`str` will be encoded to UTF-8.
|
||||
|
||||
Raises:
|
||||
ConnectionClosed: When the connection is closed.
|
||||
|
||||
"""
|
||||
if isinstance(data, BytesLike):
|
||||
data = bytes(data)
|
||||
elif isinstance(data, str):
|
||||
data = data.encode()
|
||||
else:
|
||||
raise TypeError("data must be str or bytes-like")
|
||||
|
||||
with self.send_context():
|
||||
self.protocol.send_pong(data)
|
||||
|
||||
# Private methods
|
||||
|
||||
def process_event(self, event: Event) -> None:
|
||||
"""
|
||||
Process one incoming event.
|
||||
|
||||
This method is overridden in subclasses to handle the handshake.
|
||||
|
||||
"""
|
||||
assert isinstance(event, Frame)
|
||||
if event.opcode in DATA_OPCODES:
|
||||
self.recv_messages.put(event)
|
||||
|
||||
if event.opcode is Opcode.PONG:
|
||||
self.acknowledge_pings(bytes(event.data))
|
||||
|
||||
def acknowledge_pings(self, data: bytes) -> None:
|
||||
"""
|
||||
Acknowledge pings when receiving a pong.
|
||||
|
||||
"""
|
||||
with self.protocol_mutex:
|
||||
# Ignore unsolicited pong.
|
||||
if data not in self.ping_waiters:
|
||||
return
|
||||
# Sending a pong for only the most recent ping is legal.
|
||||
# Acknowledge all previous pings too in that case.
|
||||
ping_id = None
|
||||
ping_ids = []
|
||||
for ping_id, ping in self.ping_waiters.items():
|
||||
ping_ids.append(ping_id)
|
||||
ping.set()
|
||||
if ping_id == data:
|
||||
break
|
||||
else:
|
||||
raise AssertionError("solicited pong not found in pings")
|
||||
# Remove acknowledged pings from self.ping_waiters.
|
||||
for ping_id in ping_ids:
|
||||
del self.ping_waiters[ping_id]
|
||||
|
||||
def recv_events(self) -> None:
|
||||
"""
|
||||
Read incoming data from the socket and process events.
|
||||
|
||||
Run this method in a thread as long as the connection is alive.
|
||||
|
||||
``recv_events()`` exits immediately when the ``self.socket`` is closed.
|
||||
|
||||
"""
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
if self.close_deadline is not None:
|
||||
self.socket.settimeout(self.close_deadline.timeout())
|
||||
data = self.socket.recv(self.recv_bufsize)
|
||||
except Exception as exc:
|
||||
if self.debug:
|
||||
self.logger.debug("error while receiving data", exc_info=True)
|
||||
# When the closing handshake is initiated by our side,
|
||||
# recv() may block until send_context() closes the socket.
|
||||
# In that case, send_context() already set recv_exc.
|
||||
# Calling set_recv_exc() avoids overwriting it.
|
||||
with self.protocol_mutex:
|
||||
self.set_recv_exc(exc)
|
||||
break
|
||||
|
||||
if data == b"":
|
||||
break
|
||||
|
||||
# Acquire the connection lock.
|
||||
with self.protocol_mutex:
|
||||
# Feed incoming data to the protocol.
|
||||
self.protocol.receive_data(data)
|
||||
|
||||
# This isn't expected to raise an exception.
|
||||
events = self.protocol.events_received()
|
||||
|
||||
# Write outgoing data to the socket.
|
||||
try:
|
||||
self.send_data()
|
||||
except Exception as exc:
|
||||
if self.debug:
|
||||
self.logger.debug("error while sending data", exc_info=True)
|
||||
# Similarly to the above, avoid overriding an exception
|
||||
# set by send_context(), in case of a race condition
|
||||
# i.e. send_context() closes the socket after recv()
|
||||
# returns above but before send_data() calls send().
|
||||
self.set_recv_exc(exc)
|
||||
break
|
||||
|
||||
if self.protocol.close_expected():
|
||||
# If the connection is expected to close soon, set the
|
||||
# close deadline based on the close timeout.
|
||||
if self.close_deadline is None:
|
||||
self.close_deadline = Deadline(self.close_timeout)
|
||||
|
||||
# Unlock conn_mutex before processing events. Else, the
|
||||
# application can't send messages in response to events.
|
||||
|
||||
# If self.send_data raised an exception, then events are lost.
|
||||
# Given that automatic responses write small amounts of data,
|
||||
# this should be uncommon, so we don't handle the edge case.
|
||||
|
||||
try:
|
||||
for event in events:
|
||||
# This may raise EOFError if the closing handshake
|
||||
# times out while a message is waiting to be read.
|
||||
self.process_event(event)
|
||||
except EOFError:
|
||||
break
|
||||
|
||||
# Breaking out of the while True: ... loop means that we believe
|
||||
# that the socket doesn't work anymore.
|
||||
with self.protocol_mutex:
|
||||
# Feed the end of the data stream to the protocol.
|
||||
self.protocol.receive_eof()
|
||||
|
||||
# This isn't expected to generate events.
|
||||
assert not self.protocol.events_received()
|
||||
|
||||
# There is no error handling because send_data() can only write
|
||||
# the end of the data stream here and it handles errors itself.
|
||||
self.send_data()
|
||||
|
||||
except Exception as exc:
|
||||
# This branch should never run. It's a safety net in case of bugs.
|
||||
self.logger.error("unexpected internal error", exc_info=True)
|
||||
with self.protocol_mutex:
|
||||
self.set_recv_exc(exc)
|
||||
# We don't know where we crashed. Force protocol state to CLOSED.
|
||||
self.protocol.state = CLOSED
|
||||
finally:
|
||||
# This isn't expected to raise an exception.
|
||||
self.close_socket()
|
||||
|
||||
@contextlib.contextmanager
|
||||
def send_context(
|
||||
self,
|
||||
*,
|
||||
expected_state: State = OPEN, # CONNECTING during the opening handshake
|
||||
) -> Iterator[None]:
|
||||
"""
|
||||
Create a context for writing to the connection from user code.
|
||||
|
||||
On entry, :meth:`send_context` acquires the connection lock and checks
|
||||
that the connection is open; on exit, it writes outgoing data to the
|
||||
socket::
|
||||
|
||||
with self.send_context():
|
||||
self.protocol.send_text(message.encode())
|
||||
|
||||
When the connection isn't open on entry, when the connection is expected
|
||||
to close on exit, or when an unexpected error happens, terminating the
|
||||
connection, :meth:`send_context` waits until the connection is closed
|
||||
then raises :exc:`~websockets.exceptions.ConnectionClosed`.
|
||||
|
||||
"""
|
||||
# Should we wait until the connection is closed?
|
||||
wait_for_close = False
|
||||
# Should we close the socket and raise ConnectionClosed?
|
||||
raise_close_exc = False
|
||||
# What exception should we chain ConnectionClosed to?
|
||||
original_exc: BaseException | None = None
|
||||
|
||||
# Acquire the protocol lock.
|
||||
with self.protocol_mutex:
|
||||
if self.protocol.state is expected_state:
|
||||
# Let the caller interact with the protocol.
|
||||
try:
|
||||
yield
|
||||
except (ProtocolError, ConcurrencyError):
|
||||
# The protocol state wasn't changed. Exit immediately.
|
||||
raise
|
||||
except Exception as exc:
|
||||
self.logger.error("unexpected internal error", exc_info=True)
|
||||
# This branch should never run. It's a safety net in case of
|
||||
# bugs. Since we don't know what happened, we will close the
|
||||
# connection and raise the exception to the caller.
|
||||
wait_for_close = False
|
||||
raise_close_exc = True
|
||||
original_exc = exc
|
||||
else:
|
||||
# Check if the connection is expected to close soon.
|
||||
if self.protocol.close_expected():
|
||||
wait_for_close = True
|
||||
# If the connection is expected to close soon, set the
|
||||
# close deadline based on the close timeout.
|
||||
# Since we tested earlier that protocol.state was OPEN
|
||||
# (or CONNECTING) and we didn't release protocol_mutex,
|
||||
# it is certain that self.close_deadline is still None.
|
||||
assert self.close_deadline is None
|
||||
self.close_deadline = Deadline(self.close_timeout)
|
||||
# Write outgoing data to the socket.
|
||||
try:
|
||||
self.send_data()
|
||||
except Exception as exc:
|
||||
if self.debug:
|
||||
self.logger.debug("error while sending data", exc_info=True)
|
||||
# While the only expected exception here is OSError,
|
||||
# other exceptions would be treated identically.
|
||||
wait_for_close = False
|
||||
raise_close_exc = True
|
||||
original_exc = exc
|
||||
|
||||
else: # self.protocol.state is not expected_state
|
||||
# Minor layering violation: we assume that the connection
|
||||
# will be closing soon if it isn't in the expected state.
|
||||
wait_for_close = True
|
||||
raise_close_exc = True
|
||||
|
||||
# To avoid a deadlock, release the connection lock by exiting the
|
||||
# context manager before waiting for recv_events() to terminate.
|
||||
|
||||
# If the connection is expected to close soon and the close timeout
|
||||
# elapses, close the socket to terminate the connection.
|
||||
if wait_for_close:
|
||||
if self.close_deadline is None:
|
||||
timeout = self.close_timeout
|
||||
else:
|
||||
# Thread.join() returns immediately if timeout is negative.
|
||||
timeout = self.close_deadline.timeout(raise_if_elapsed=False)
|
||||
self.recv_events_thread.join(timeout)
|
||||
|
||||
if self.recv_events_thread.is_alive():
|
||||
# There's no risk to overwrite another error because
|
||||
# original_exc is never set when wait_for_close is True.
|
||||
assert original_exc is None
|
||||
original_exc = TimeoutError("timed out while closing connection")
|
||||
# Set recv_exc before closing the socket in order to get
|
||||
# proper exception reporting.
|
||||
raise_close_exc = True
|
||||
with self.protocol_mutex:
|
||||
self.set_recv_exc(original_exc)
|
||||
|
||||
# If an error occurred, close the socket to terminate the connection and
|
||||
# raise an exception.
|
||||
if raise_close_exc:
|
||||
self.close_socket()
|
||||
self.recv_events_thread.join()
|
||||
raise self.protocol.close_exc from original_exc
|
||||
|
||||
def send_data(self) -> None:
|
||||
"""
|
||||
Send outgoing data.
|
||||
|
||||
This method requires holding protocol_mutex.
|
||||
|
||||
Raises:
|
||||
OSError: When a socket operations fails.
|
||||
|
||||
"""
|
||||
assert self.protocol_mutex.locked()
|
||||
for data in self.protocol.data_to_send():
|
||||
if data:
|
||||
if self.close_deadline is not None:
|
||||
self.socket.settimeout(self.close_deadline.timeout())
|
||||
self.socket.sendall(data)
|
||||
else:
|
||||
try:
|
||||
self.socket.shutdown(socket.SHUT_WR)
|
||||
except OSError: # socket already closed
|
||||
pass
|
||||
|
||||
def set_recv_exc(self, exc: BaseException | None) -> None:
|
||||
"""
|
||||
Set recv_exc, if not set yet.
|
||||
|
||||
This method requires holding protocol_mutex.
|
||||
|
||||
"""
|
||||
assert self.protocol_mutex.locked()
|
||||
if self.recv_exc is None: # pragma: no branch
|
||||
self.recv_exc = exc
|
||||
|
||||
def close_socket(self) -> None:
|
||||
"""
|
||||
Shutdown and close socket. Close message assembler.
|
||||
|
||||
Calling close_socket() guarantees that recv_events() terminates. Indeed,
|
||||
recv_events() may block only on socket.recv() or on recv_messages.put().
|
||||
|
||||
"""
|
||||
# shutdown() is required to interrupt recv() on Linux.
|
||||
try:
|
||||
self.socket.shutdown(socket.SHUT_RDWR)
|
||||
except OSError:
|
||||
pass # socket is already closed
|
||||
self.socket.close()
|
||||
self.recv_messages.close()
|
@ -1,283 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import codecs
|
||||
import queue
|
||||
import threading
|
||||
from typing import Iterator, cast
|
||||
|
||||
from ..exceptions import ConcurrencyError
|
||||
from ..frames import OP_BINARY, OP_CONT, OP_TEXT, Frame
|
||||
from ..typing import Data
|
||||
|
||||
|
||||
__all__ = ["Assembler"]
|
||||
|
||||
UTF8Decoder = codecs.getincrementaldecoder("utf-8")
|
||||
|
||||
|
||||
class Assembler:
|
||||
"""
|
||||
Assemble messages from frames.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
# Serialize reads and writes -- except for reads via synchronization
|
||||
# primitives provided by the threading and queue modules.
|
||||
self.mutex = threading.Lock()
|
||||
|
||||
# We create a latch with two events to synchronize the production of
|
||||
# frames and the consumption of messages (or frames) without a buffer.
|
||||
# This design requires a switch between the library thread and the user
|
||||
# thread for each message; that shouldn't be a performance bottleneck.
|
||||
|
||||
# put() sets this event to tell get() that a message can be fetched.
|
||||
self.message_complete = threading.Event()
|
||||
# get() sets this event to let put() that the message was fetched.
|
||||
self.message_fetched = threading.Event()
|
||||
|
||||
# This flag prevents concurrent calls to get() by user code.
|
||||
self.get_in_progress = False
|
||||
# This flag prevents concurrent calls to put() by library code.
|
||||
self.put_in_progress = False
|
||||
|
||||
# Decoder for text frames, None for binary frames.
|
||||
self.decoder: codecs.IncrementalDecoder | None = None
|
||||
|
||||
# Buffer of frames belonging to the same message.
|
||||
self.chunks: list[Data] = []
|
||||
|
||||
# When switching from "buffering" to "streaming", we use a thread-safe
|
||||
# queue for transferring frames from the writing thread (library code)
|
||||
# to the reading thread (user code). We're buffering when chunks_queue
|
||||
# is None and streaming when it's a SimpleQueue. None is a sentinel
|
||||
# value marking the end of the message, superseding message_complete.
|
||||
|
||||
# Stream data from frames belonging to the same message.
|
||||
self.chunks_queue: queue.SimpleQueue[Data | None] | None = None
|
||||
|
||||
# This flag marks the end of the connection.
|
||||
self.closed = False
|
||||
|
||||
def get(self, timeout: float | None = None) -> Data:
|
||||
"""
|
||||
Read the next message.
|
||||
|
||||
:meth:`get` returns a single :class:`str` or :class:`bytes`.
|
||||
|
||||
If the message is fragmented, :meth:`get` waits until the last frame is
|
||||
received, then it reassembles the message and returns it. To receive
|
||||
messages frame by frame, use :meth:`get_iter` instead.
|
||||
|
||||
Args:
|
||||
timeout: If a timeout is provided and elapses before a complete
|
||||
message is received, :meth:`get` raises :exc:`TimeoutError`.
|
||||
|
||||
Raises:
|
||||
EOFError: If the stream of frames has ended.
|
||||
ConcurrencyError: If two threads run :meth:`get` or :meth:`get_iter`
|
||||
concurrently.
|
||||
TimeoutError: If a timeout is provided and elapses before a
|
||||
complete message is received.
|
||||
|
||||
"""
|
||||
with self.mutex:
|
||||
if self.closed:
|
||||
raise EOFError("stream of frames ended")
|
||||
|
||||
if self.get_in_progress:
|
||||
raise ConcurrencyError("get() or get_iter() is already running")
|
||||
|
||||
self.get_in_progress = True
|
||||
|
||||
# If the message_complete event isn't set yet, release the lock to
|
||||
# allow put() to run and eventually set it.
|
||||
# Locking with get_in_progress ensures only one thread can get here.
|
||||
completed = self.message_complete.wait(timeout)
|
||||
|
||||
with self.mutex:
|
||||
self.get_in_progress = False
|
||||
|
||||
# Waiting for a complete message timed out.
|
||||
if not completed:
|
||||
raise TimeoutError(f"timed out in {timeout:.1f}s")
|
||||
|
||||
# get() was unblocked by close() rather than put().
|
||||
if self.closed:
|
||||
raise EOFError("stream of frames ended")
|
||||
|
||||
assert self.message_complete.is_set()
|
||||
self.message_complete.clear()
|
||||
|
||||
joiner: Data = b"" if self.decoder is None else ""
|
||||
# mypy cannot figure out that chunks have the proper type.
|
||||
message: Data = joiner.join(self.chunks) # type: ignore
|
||||
|
||||
self.chunks = []
|
||||
assert self.chunks_queue is None
|
||||
|
||||
assert not self.message_fetched.is_set()
|
||||
self.message_fetched.set()
|
||||
|
||||
return message
|
||||
|
||||
def get_iter(self) -> Iterator[Data]:
|
||||
"""
|
||||
Stream the next message.
|
||||
|
||||
Iterating the return value of :meth:`get_iter` yields a :class:`str` or
|
||||
:class:`bytes` for each frame in the message.
|
||||
|
||||
The iterator must be fully consumed before calling :meth:`get_iter` or
|
||||
:meth:`get` again. Else, :exc:`ConcurrencyError` is raised.
|
||||
|
||||
This method only makes sense for fragmented messages. If messages aren't
|
||||
fragmented, use :meth:`get` instead.
|
||||
|
||||
Raises:
|
||||
EOFError: If the stream of frames has ended.
|
||||
ConcurrencyError: If two threads run :meth:`get` or :meth:`get_iter`
|
||||
concurrently.
|
||||
|
||||
"""
|
||||
with self.mutex:
|
||||
if self.closed:
|
||||
raise EOFError("stream of frames ended")
|
||||
|
||||
if self.get_in_progress:
|
||||
raise ConcurrencyError("get() or get_iter() is already running")
|
||||
|
||||
chunks = self.chunks
|
||||
self.chunks = []
|
||||
self.chunks_queue = cast(
|
||||
# Remove quotes around type when dropping Python < 3.9.
|
||||
"queue.SimpleQueue[Data | None]",
|
||||
queue.SimpleQueue(),
|
||||
)
|
||||
|
||||
# Sending None in chunk_queue supersedes setting message_complete
|
||||
# when switching to "streaming". If message is already complete
|
||||
# when the switch happens, put() didn't send None, so we have to.
|
||||
if self.message_complete.is_set():
|
||||
self.chunks_queue.put(None)
|
||||
|
||||
self.get_in_progress = True
|
||||
|
||||
# Locking with get_in_progress ensures only one thread can get here.
|
||||
chunk: Data | None
|
||||
for chunk in chunks:
|
||||
yield chunk
|
||||
while (chunk := self.chunks_queue.get()) is not None:
|
||||
yield chunk
|
||||
|
||||
with self.mutex:
|
||||
self.get_in_progress = False
|
||||
|
||||
# get_iter() was unblocked by close() rather than put().
|
||||
if self.closed:
|
||||
raise EOFError("stream of frames ended")
|
||||
|
||||
assert self.message_complete.is_set()
|
||||
self.message_complete.clear()
|
||||
|
||||
assert self.chunks == []
|
||||
self.chunks_queue = None
|
||||
|
||||
assert not self.message_fetched.is_set()
|
||||
self.message_fetched.set()
|
||||
|
||||
def put(self, frame: Frame) -> None:
|
||||
"""
|
||||
Add ``frame`` to the next message.
|
||||
|
||||
When ``frame`` is the final frame in a message, :meth:`put` waits until
|
||||
the message is fetched, which can be achieved by calling :meth:`get` or
|
||||
by fully consuming the return value of :meth:`get_iter`.
|
||||
|
||||
:meth:`put` assumes that the stream of frames respects the protocol. If
|
||||
it doesn't, the behavior is undefined.
|
||||
|
||||
Raises:
|
||||
EOFError: If the stream of frames has ended.
|
||||
ConcurrencyError: If two threads run :meth:`put` concurrently.
|
||||
|
||||
"""
|
||||
with self.mutex:
|
||||
if self.closed:
|
||||
raise EOFError("stream of frames ended")
|
||||
|
||||
if self.put_in_progress:
|
||||
raise ConcurrencyError("put is already running")
|
||||
|
||||
if frame.opcode is OP_TEXT:
|
||||
self.decoder = UTF8Decoder(errors="strict")
|
||||
elif frame.opcode is OP_BINARY:
|
||||
self.decoder = None
|
||||
else:
|
||||
assert frame.opcode is OP_CONT
|
||||
|
||||
data: Data
|
||||
if self.decoder is not None:
|
||||
data = self.decoder.decode(frame.data, frame.fin)
|
||||
else:
|
||||
data = frame.data
|
||||
|
||||
if self.chunks_queue is None:
|
||||
self.chunks.append(data)
|
||||
else:
|
||||
self.chunks_queue.put(data)
|
||||
|
||||
if not frame.fin:
|
||||
return
|
||||
|
||||
# Message is complete. Wait until it's fetched to return.
|
||||
|
||||
assert not self.message_complete.is_set()
|
||||
self.message_complete.set()
|
||||
|
||||
if self.chunks_queue is not None:
|
||||
self.chunks_queue.put(None)
|
||||
|
||||
assert not self.message_fetched.is_set()
|
||||
|
||||
self.put_in_progress = True
|
||||
|
||||
# Release the lock to allow get() to run and eventually set the event.
|
||||
# Locking with put_in_progress ensures only one coroutine can get here.
|
||||
self.message_fetched.wait()
|
||||
|
||||
with self.mutex:
|
||||
self.put_in_progress = False
|
||||
|
||||
# put() was unblocked by close() rather than get() or get_iter().
|
||||
if self.closed:
|
||||
raise EOFError("stream of frames ended")
|
||||
|
||||
assert self.message_fetched.is_set()
|
||||
self.message_fetched.clear()
|
||||
|
||||
self.decoder = None
|
||||
|
||||
def close(self) -> None:
|
||||
"""
|
||||
End the stream of frames.
|
||||
|
||||
Callling :meth:`close` concurrently with :meth:`get`, :meth:`get_iter`,
|
||||
or :meth:`put` is safe. They will raise :exc:`EOFError`.
|
||||
|
||||
"""
|
||||
with self.mutex:
|
||||
if self.closed:
|
||||
return
|
||||
|
||||
self.closed = True
|
||||
|
||||
# Unblock get or get_iter.
|
||||
if self.get_in_progress:
|
||||
self.message_complete.set()
|
||||
if self.chunks_queue is not None:
|
||||
self.chunks_queue.put(None)
|
||||
|
||||
# Unblock put().
|
||||
if self.put_in_progress:
|
||||
self.message_fetched.set()
|
@ -1,727 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import hmac
|
||||
import http
|
||||
import logging
|
||||
import os
|
||||
import selectors
|
||||
import socket
|
||||
import ssl as ssl_module
|
||||
import sys
|
||||
import threading
|
||||
import warnings
|
||||
from types import TracebackType
|
||||
from typing import Any, Callable, Iterable, Sequence, Tuple, cast
|
||||
|
||||
from ..exceptions import InvalidHeader
|
||||
from ..extensions.base import ServerExtensionFactory
|
||||
from ..extensions.permessage_deflate import enable_server_permessage_deflate
|
||||
from ..frames import CloseCode
|
||||
from ..headers import (
|
||||
build_www_authenticate_basic,
|
||||
parse_authorization_basic,
|
||||
validate_subprotocols,
|
||||
)
|
||||
from ..http11 import SERVER, Request, Response
|
||||
from ..protocol import CONNECTING, OPEN, Event
|
||||
from ..server import ServerProtocol
|
||||
from ..typing import LoggerLike, Origin, StatusLike, Subprotocol
|
||||
from .connection import Connection
|
||||
from .utils import Deadline
|
||||
|
||||
|
||||
__all__ = ["serve", "unix_serve", "ServerConnection", "Server", "basic_auth"]
|
||||
|
||||
|
||||
class ServerConnection(Connection):
|
||||
"""
|
||||
:mod:`threading` implementation of a WebSocket server connection.
|
||||
|
||||
:class:`ServerConnection` provides :meth:`recv` and :meth:`send` methods for
|
||||
receiving and sending messages.
|
||||
|
||||
It supports iteration to receive messages::
|
||||
|
||||
for message in websocket:
|
||||
process(message)
|
||||
|
||||
The iterator exits normally when the connection is closed with close code
|
||||
1000 (OK) or 1001 (going away) or without a close code. It raises a
|
||||
:exc:`~websockets.exceptions.ConnectionClosedError` when the connection is
|
||||
closed with any other code.
|
||||
|
||||
Args:
|
||||
socket: Socket connected to a WebSocket client.
|
||||
protocol: Sans-I/O connection.
|
||||
close_timeout: Timeout for closing the connection in seconds.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
socket: socket.socket,
|
||||
protocol: ServerProtocol,
|
||||
*,
|
||||
close_timeout: float | None = 10,
|
||||
) -> None:
|
||||
self.protocol: ServerProtocol
|
||||
self.request_rcvd = threading.Event()
|
||||
super().__init__(
|
||||
socket,
|
||||
protocol,
|
||||
close_timeout=close_timeout,
|
||||
)
|
||||
self.username: str # see basic_auth()
|
||||
|
||||
def respond(self, status: StatusLike, text: str) -> Response:
|
||||
"""
|
||||
Create a plain text HTTP response.
|
||||
|
||||
``process_request`` and ``process_response`` may call this method to
|
||||
return an HTTP response instead of performing the WebSocket opening
|
||||
handshake.
|
||||
|
||||
You can modify the response before returning it, for example by changing
|
||||
HTTP headers.
|
||||
|
||||
Args:
|
||||
status: HTTP status code.
|
||||
text: HTTP response body; it will be encoded to UTF-8.
|
||||
|
||||
Returns:
|
||||
HTTP response to send to the client.
|
||||
|
||||
"""
|
||||
return self.protocol.reject(status, text)
|
||||
|
||||
def handshake(
|
||||
self,
|
||||
process_request: (
|
||||
Callable[
|
||||
[ServerConnection, Request],
|
||||
Response | None,
|
||||
]
|
||||
| None
|
||||
) = None,
|
||||
process_response: (
|
||||
Callable[
|
||||
[ServerConnection, Request, Response],
|
||||
Response | None,
|
||||
]
|
||||
| None
|
||||
) = None,
|
||||
server_header: str | None = SERVER,
|
||||
timeout: float | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Perform the opening handshake.
|
||||
|
||||
"""
|
||||
if not self.request_rcvd.wait(timeout):
|
||||
raise TimeoutError("timed out during handshake")
|
||||
|
||||
if self.request is not None:
|
||||
with self.send_context(expected_state=CONNECTING):
|
||||
response = None
|
||||
|
||||
if process_request is not None:
|
||||
try:
|
||||
response = process_request(self, self.request)
|
||||
except Exception as exc:
|
||||
self.protocol.handshake_exc = exc
|
||||
response = self.protocol.reject(
|
||||
http.HTTPStatus.INTERNAL_SERVER_ERROR,
|
||||
(
|
||||
"Failed to open a WebSocket connection.\n"
|
||||
"See server log for more information.\n"
|
||||
),
|
||||
)
|
||||
|
||||
if response is None:
|
||||
self.response = self.protocol.accept(self.request)
|
||||
else:
|
||||
self.response = response
|
||||
|
||||
if server_header:
|
||||
self.response.headers["Server"] = server_header
|
||||
|
||||
response = None
|
||||
|
||||
if process_response is not None:
|
||||
try:
|
||||
response = process_response(self, self.request, self.response)
|
||||
except Exception as exc:
|
||||
self.protocol.handshake_exc = exc
|
||||
response = self.protocol.reject(
|
||||
http.HTTPStatus.INTERNAL_SERVER_ERROR,
|
||||
(
|
||||
"Failed to open a WebSocket connection.\n"
|
||||
"See server log for more information.\n"
|
||||
),
|
||||
)
|
||||
|
||||
if response is not None:
|
||||
self.response = response
|
||||
|
||||
self.protocol.send_response(self.response)
|
||||
|
||||
# self.protocol.handshake_exc is always set when the connection is lost
|
||||
# before receiving a request, when the request cannot be parsed, when
|
||||
# the handshake encounters an error, or when process_request or
|
||||
# process_response sends a HTTP response that rejects the handshake.
|
||||
|
||||
if self.protocol.handshake_exc is not None:
|
||||
raise self.protocol.handshake_exc
|
||||
|
||||
def process_event(self, event: Event) -> None:
|
||||
"""
|
||||
Process one incoming event.
|
||||
|
||||
"""
|
||||
# First event - handshake request.
|
||||
if self.request is None:
|
||||
assert isinstance(event, Request)
|
||||
self.request = event
|
||||
self.request_rcvd.set()
|
||||
# Later events - frames.
|
||||
else:
|
||||
super().process_event(event)
|
||||
|
||||
def recv_events(self) -> None:
|
||||
"""
|
||||
Read incoming data from the socket and process events.
|
||||
|
||||
"""
|
||||
try:
|
||||
super().recv_events()
|
||||
finally:
|
||||
# If the connection is closed during the handshake, unblock it.
|
||||
self.request_rcvd.set()
|
||||
|
||||
|
||||
class Server:
|
||||
"""
|
||||
WebSocket server returned by :func:`serve`.
|
||||
|
||||
This class mirrors the API of :class:`~socketserver.BaseServer`, notably the
|
||||
:meth:`~socketserver.BaseServer.serve_forever` and
|
||||
:meth:`~socketserver.BaseServer.shutdown` methods, as well as the context
|
||||
manager protocol.
|
||||
|
||||
Args:
|
||||
socket: Server socket listening for new connections.
|
||||
handler: Handler for one connection. Receives the socket and address
|
||||
returned by :meth:`~socket.socket.accept`.
|
||||
logger: Logger for this server.
|
||||
It defaults to ``logging.getLogger("websockets.server")``.
|
||||
See the :doc:`logging guide <../../topics/logging>` for details.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
socket: socket.socket,
|
||||
handler: Callable[[socket.socket, Any], None],
|
||||
logger: LoggerLike | None = None,
|
||||
) -> None:
|
||||
self.socket = socket
|
||||
self.handler = handler
|
||||
if logger is None:
|
||||
logger = logging.getLogger("websockets.server")
|
||||
self.logger = logger
|
||||
if sys.platform != "win32":
|
||||
self.shutdown_watcher, self.shutdown_notifier = os.pipe()
|
||||
|
||||
def serve_forever(self) -> None:
|
||||
"""
|
||||
See :meth:`socketserver.BaseServer.serve_forever`.
|
||||
|
||||
This method doesn't return. Calling :meth:`shutdown` from another thread
|
||||
stops the server.
|
||||
|
||||
Typical use::
|
||||
|
||||
with serve(...) as server:
|
||||
server.serve_forever()
|
||||
|
||||
"""
|
||||
poller = selectors.DefaultSelector()
|
||||
try:
|
||||
poller.register(self.socket, selectors.EVENT_READ)
|
||||
except ValueError: # pragma: no cover
|
||||
# If shutdown() is called before poller.register(),
|
||||
# the socket is closed and poller.register() raises
|
||||
# ValueError: Invalid file descriptor: -1
|
||||
return
|
||||
if sys.platform != "win32":
|
||||
poller.register(self.shutdown_watcher, selectors.EVENT_READ)
|
||||
|
||||
while True:
|
||||
poller.select()
|
||||
try:
|
||||
# If the socket is closed, this will raise an exception and exit
|
||||
# the loop. So we don't need to check the return value of select().
|
||||
sock, addr = self.socket.accept()
|
||||
except OSError:
|
||||
break
|
||||
# Since there isn't a mechanism for tracking connections and waiting
|
||||
# for them to terminate, we cannot use daemon threads, or else all
|
||||
# connections would be terminate brutally when closing the server.
|
||||
thread = threading.Thread(target=self.handler, args=(sock, addr))
|
||||
thread.start()
|
||||
|
||||
def shutdown(self) -> None:
|
||||
"""
|
||||
See :meth:`socketserver.BaseServer.shutdown`.
|
||||
|
||||
"""
|
||||
self.socket.close()
|
||||
if sys.platform != "win32":
|
||||
os.write(self.shutdown_notifier, b"x")
|
||||
|
||||
def fileno(self) -> int:
|
||||
"""
|
||||
See :meth:`socketserver.BaseServer.fileno`.
|
||||
|
||||
"""
|
||||
return self.socket.fileno()
|
||||
|
||||
def __enter__(self) -> Server:
|
||||
return self
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc_value: BaseException | None,
|
||||
traceback: TracebackType | None,
|
||||
) -> None:
|
||||
self.shutdown()
|
||||
|
||||
|
||||
def __getattr__(name: str) -> Any:
|
||||
if name == "WebSocketServer":
|
||||
warnings.warn( # deprecated in 13.0 - 2024-08-20
|
||||
"WebSocketServer was renamed to Server",
|
||||
DeprecationWarning,
|
||||
)
|
||||
return Server
|
||||
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
||||
|
||||
|
||||
def serve(
|
||||
handler: Callable[[ServerConnection], None],
|
||||
host: str | None = None,
|
||||
port: int | None = None,
|
||||
*,
|
||||
# TCP/TLS
|
||||
sock: socket.socket | None = None,
|
||||
ssl: ssl_module.SSLContext | None = None,
|
||||
# WebSocket
|
||||
origins: Sequence[Origin | None] | None = None,
|
||||
extensions: Sequence[ServerExtensionFactory] | None = None,
|
||||
subprotocols: Sequence[Subprotocol] | None = None,
|
||||
select_subprotocol: (
|
||||
Callable[
|
||||
[ServerConnection, Sequence[Subprotocol]],
|
||||
Subprotocol | None,
|
||||
]
|
||||
| None
|
||||
) = None,
|
||||
process_request: (
|
||||
Callable[
|
||||
[ServerConnection, Request],
|
||||
Response | None,
|
||||
]
|
||||
| None
|
||||
) = None,
|
||||
process_response: (
|
||||
Callable[
|
||||
[ServerConnection, Request, Response],
|
||||
Response | None,
|
||||
]
|
||||
| None
|
||||
) = None,
|
||||
server_header: str | None = SERVER,
|
||||
compression: str | None = "deflate",
|
||||
# Timeouts
|
||||
open_timeout: float | None = 10,
|
||||
close_timeout: float | None = 10,
|
||||
# Limits
|
||||
max_size: int | None = 2**20,
|
||||
# Logging
|
||||
logger: LoggerLike | None = None,
|
||||
# Escape hatch for advanced customization
|
||||
create_connection: type[ServerConnection] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Server:
|
||||
"""
|
||||
Create a WebSocket server listening on ``host`` and ``port``.
|
||||
|
||||
Whenever a client connects, the server creates a :class:`ServerConnection`,
|
||||
performs the opening handshake, and delegates to the ``handler``.
|
||||
|
||||
The handler receives the :class:`ServerConnection` instance, which you can
|
||||
use to send and receive messages.
|
||||
|
||||
Once the handler completes, either normally or with an exception, the server
|
||||
performs the closing handshake and closes the connection.
|
||||
|
||||
This function returns a :class:`Server` whose API mirrors
|
||||
:class:`~socketserver.BaseServer`. Treat it as a context manager to ensure
|
||||
that it will be closed and call :meth:`~Server.serve_forever` to serve
|
||||
requests::
|
||||
|
||||
from websockets.sync.server import serve
|
||||
|
||||
def handler(websocket):
|
||||
...
|
||||
|
||||
with serve(handler, ...) as server:
|
||||
server.serve_forever()
|
||||
|
||||
Args:
|
||||
handler: Connection handler. It receives the WebSocket connection,
|
||||
which is a :class:`ServerConnection`, in argument.
|
||||
host: Network interfaces the server binds to.
|
||||
See :func:`~socket.create_server` for details.
|
||||
port: TCP port the server listens on.
|
||||
See :func:`~socket.create_server` for details.
|
||||
sock: Preexisting TCP socket. ``sock`` replaces ``host`` and ``port``.
|
||||
You may call :func:`socket.create_server` to create a suitable TCP
|
||||
socket.
|
||||
ssl: Configuration for enabling TLS on the connection.
|
||||
origins: Acceptable values of the ``Origin`` header, for defending
|
||||
against Cross-Site WebSocket Hijacking attacks. Include :obj:`None`
|
||||
in the list if the lack of an origin is acceptable.
|
||||
extensions: List of supported extensions, in order in which they
|
||||
should be negotiated and run.
|
||||
subprotocols: List of supported subprotocols, in order of decreasing
|
||||
preference.
|
||||
select_subprotocol: Callback for selecting a subprotocol among
|
||||
those supported by the client and the server. It receives a
|
||||
:class:`ServerConnection` (not a
|
||||
:class:`~websockets.server.ServerProtocol`!) instance and a list of
|
||||
subprotocols offered by the client. Other than the first argument,
|
||||
it has the same behavior as the
|
||||
:meth:`ServerProtocol.select_subprotocol
|
||||
<websockets.server.ServerProtocol.select_subprotocol>` method.
|
||||
process_request: Intercept the request during the opening handshake.
|
||||
Return an HTTP response to force the response. Return :obj:`None` to
|
||||
continue normally. When you force an HTTP 101 Continue response, the
|
||||
handshake is successful. Else, the connection is aborted.
|
||||
process_response: Intercept the response during the opening handshake.
|
||||
Modify the response or return a new HTTP response to force the
|
||||
response. Return :obj:`None` to continue normally. When you force an
|
||||
HTTP 101 Continue response, the handshake is successful. Else, the
|
||||
connection is aborted.
|
||||
server_header: Value of the ``Server`` response header.
|
||||
It defaults to ``"Python/x.y.z websockets/X.Y"``. Setting it to
|
||||
:obj:`None` removes the header.
|
||||
compression: The "permessage-deflate" extension is enabled by default.
|
||||
Set ``compression`` to :obj:`None` to disable it. See the
|
||||
:doc:`compression guide <../../topics/compression>` for details.
|
||||
open_timeout: Timeout for opening connections in seconds.
|
||||
:obj:`None` disables the timeout.
|
||||
close_timeout: Timeout for closing connections in seconds.
|
||||
:obj:`None` disables the timeout.
|
||||
max_size: Maximum size of incoming messages in bytes.
|
||||
:obj:`None` disables the limit.
|
||||
logger: Logger for this server.
|
||||
It defaults to ``logging.getLogger("websockets.server")``. See the
|
||||
:doc:`logging guide <../../topics/logging>` for details.
|
||||
create_connection: Factory for the :class:`ServerConnection` managing
|
||||
the connection. Set it to a wrapper or a subclass to customize
|
||||
connection handling.
|
||||
|
||||
Any other keyword arguments are passed to :func:`~socket.create_server`.
|
||||
|
||||
"""
|
||||
|
||||
# Process parameters
|
||||
|
||||
# Backwards compatibility: ssl used to be called ssl_context.
|
||||
if ssl is None and "ssl_context" in kwargs:
|
||||
ssl = kwargs.pop("ssl_context")
|
||||
warnings.warn( # deprecated in 13.0 - 2024-08-20
|
||||
"ssl_context was renamed to ssl",
|
||||
DeprecationWarning,
|
||||
)
|
||||
|
||||
if subprotocols is not None:
|
||||
validate_subprotocols(subprotocols)
|
||||
|
||||
if compression == "deflate":
|
||||
extensions = enable_server_permessage_deflate(extensions)
|
||||
elif compression is not None:
|
||||
raise ValueError(f"unsupported compression: {compression}")
|
||||
|
||||
if create_connection is None:
|
||||
create_connection = ServerConnection
|
||||
|
||||
# Bind socket and listen
|
||||
|
||||
# Private APIs for unix_connect()
|
||||
unix: bool = kwargs.pop("unix", False)
|
||||
path: str | None = kwargs.pop("path", None)
|
||||
|
||||
if sock is None:
|
||||
if unix:
|
||||
if path is None:
|
||||
raise TypeError("missing path argument")
|
||||
kwargs.setdefault("family", socket.AF_UNIX)
|
||||
sock = socket.create_server(path, **kwargs)
|
||||
else:
|
||||
sock = socket.create_server((host, port), **kwargs)
|
||||
else:
|
||||
if path is not None:
|
||||
raise TypeError("path and sock arguments are incompatible")
|
||||
|
||||
# Initialize TLS wrapper
|
||||
|
||||
if ssl is not None:
|
||||
sock = ssl.wrap_socket(
|
||||
sock,
|
||||
server_side=True,
|
||||
# Delay TLS handshake until after we set a timeout on the socket.
|
||||
do_handshake_on_connect=False,
|
||||
)
|
||||
|
||||
# Define request handler
|
||||
|
||||
def conn_handler(sock: socket.socket, addr: Any) -> None:
|
||||
# Calculate timeouts on the TLS and WebSocket handshakes.
|
||||
# The TLS timeout must be set on the socket, then removed
|
||||
# to avoid conflicting with the WebSocket timeout in handshake().
|
||||
deadline = Deadline(open_timeout)
|
||||
|
||||
try:
|
||||
# Disable Nagle algorithm
|
||||
|
||||
if not unix:
|
||||
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, True)
|
||||
|
||||
# Perform TLS handshake
|
||||
|
||||
if ssl is not None:
|
||||
sock.settimeout(deadline.timeout())
|
||||
# mypy cannot figure this out
|
||||
assert isinstance(sock, ssl_module.SSLSocket)
|
||||
sock.do_handshake()
|
||||
sock.settimeout(None)
|
||||
|
||||
# Create a closure to give select_subprotocol access to connection.
|
||||
protocol_select_subprotocol: (
|
||||
Callable[
|
||||
[ServerProtocol, Sequence[Subprotocol]],
|
||||
Subprotocol | None,
|
||||
]
|
||||
| None
|
||||
) = None
|
||||
if select_subprotocol is not None:
|
||||
|
||||
def protocol_select_subprotocol(
|
||||
protocol: ServerProtocol,
|
||||
subprotocols: Sequence[Subprotocol],
|
||||
) -> Subprotocol | None:
|
||||
# mypy doesn't know that select_subprotocol is immutable.
|
||||
assert select_subprotocol is not None
|
||||
# Ensure this function is only used in the intended context.
|
||||
assert protocol is connection.protocol
|
||||
return select_subprotocol(connection, subprotocols)
|
||||
|
||||
# Initialize WebSocket protocol
|
||||
|
||||
protocol = ServerProtocol(
|
||||
origins=origins,
|
||||
extensions=extensions,
|
||||
subprotocols=subprotocols,
|
||||
select_subprotocol=protocol_select_subprotocol,
|
||||
max_size=max_size,
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
# Initialize WebSocket connection
|
||||
|
||||
assert create_connection is not None # help mypy
|
||||
connection = create_connection(
|
||||
sock,
|
||||
protocol,
|
||||
close_timeout=close_timeout,
|
||||
)
|
||||
except Exception:
|
||||
sock.close()
|
||||
return
|
||||
|
||||
try:
|
||||
try:
|
||||
connection.handshake(
|
||||
process_request,
|
||||
process_response,
|
||||
server_header,
|
||||
deadline.timeout(),
|
||||
)
|
||||
except TimeoutError:
|
||||
connection.close_socket()
|
||||
connection.recv_events_thread.join()
|
||||
return
|
||||
except Exception:
|
||||
connection.logger.error("opening handshake failed", exc_info=True)
|
||||
connection.close_socket()
|
||||
connection.recv_events_thread.join()
|
||||
return
|
||||
|
||||
assert connection.protocol.state is OPEN
|
||||
try:
|
||||
handler(connection)
|
||||
except Exception:
|
||||
connection.logger.error("connection handler failed", exc_info=True)
|
||||
connection.close(CloseCode.INTERNAL_ERROR)
|
||||
else:
|
||||
connection.close()
|
||||
|
||||
except Exception: # pragma: no cover
|
||||
# Don't leak sockets on unexpected errors.
|
||||
sock.close()
|
||||
|
||||
# Initialize server
|
||||
|
||||
return Server(sock, conn_handler, logger)
|
||||
|
||||
|
||||
def unix_serve(
|
||||
handler: Callable[[ServerConnection], None],
|
||||
path: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Server:
|
||||
"""
|
||||
Create a WebSocket server listening on a Unix socket.
|
||||
|
||||
This function accepts the same keyword arguments as :func:`serve`.
|
||||
|
||||
It's only available on Unix.
|
||||
|
||||
It's useful for deploying a server behind a reverse proxy such as nginx.
|
||||
|
||||
Args:
|
||||
handler: Connection handler. It receives the WebSocket connection,
|
||||
which is a :class:`ServerConnection`, in argument.
|
||||
path: File system path to the Unix socket.
|
||||
|
||||
"""
|
||||
return serve(handler, unix=True, path=path, **kwargs)
|
||||
|
||||
|
||||
def is_credentials(credentials: Any) -> bool:
|
||||
try:
|
||||
username, password = credentials
|
||||
except (TypeError, ValueError):
|
||||
return False
|
||||
else:
|
||||
return isinstance(username, str) and isinstance(password, str)
|
||||
|
||||
|
||||
def basic_auth(
|
||||
realm: str = "",
|
||||
credentials: tuple[str, str] | Iterable[tuple[str, str]] | None = None,
|
||||
check_credentials: Callable[[str, str], bool] | None = None,
|
||||
) -> Callable[[ServerConnection, Request], Response | None]:
|
||||
"""
|
||||
Factory for ``process_request`` to enforce HTTP Basic Authentication.
|
||||
|
||||
:func:`basic_auth` is designed to integrate with :func:`serve` as follows::
|
||||
|
||||
from websockets.sync.server import basic_auth, serve
|
||||
|
||||
with serve(
|
||||
...,
|
||||
process_request=basic_auth(
|
||||
realm="my dev server",
|
||||
credentials=("hello", "iloveyou"),
|
||||
),
|
||||
):
|
||||
|
||||
If authentication succeeds, the connection's ``username`` attribute is set.
|
||||
If it fails, the server responds with an HTTP 401 Unauthorized status.
|
||||
|
||||
One of ``credentials`` or ``check_credentials`` must be provided; not both.
|
||||
|
||||
Args:
|
||||
realm: Scope of protection. It should contain only ASCII characters
|
||||
because the encoding of non-ASCII characters is undefined. Refer to
|
||||
section 2.2 of :rfc:`7235` for details.
|
||||
credentials: Hard coded authorized credentials. It can be a
|
||||
``(username, password)`` pair or a list of such pairs.
|
||||
check_credentials: Function that verifies credentials.
|
||||
It receives ``username`` and ``password`` arguments and returns
|
||||
whether they're valid.
|
||||
Raises:
|
||||
TypeError: If ``credentials`` or ``check_credentials`` is wrong.
|
||||
|
||||
"""
|
||||
if (credentials is None) == (check_credentials is None):
|
||||
raise TypeError("provide either credentials or check_credentials")
|
||||
|
||||
if credentials is not None:
|
||||
if is_credentials(credentials):
|
||||
credentials_list = [cast(Tuple[str, str], credentials)]
|
||||
elif isinstance(credentials, Iterable):
|
||||
credentials_list = list(cast(Iterable[Tuple[str, str]], credentials))
|
||||
if not all(is_credentials(item) for item in credentials_list):
|
||||
raise TypeError(f"invalid credentials argument: {credentials}")
|
||||
else:
|
||||
raise TypeError(f"invalid credentials argument: {credentials}")
|
||||
|
||||
credentials_dict = dict(credentials_list)
|
||||
|
||||
def check_credentials(username: str, password: str) -> bool:
|
||||
try:
|
||||
expected_password = credentials_dict[username]
|
||||
except KeyError:
|
||||
return False
|
||||
return hmac.compare_digest(expected_password, password)
|
||||
|
||||
assert check_credentials is not None # help mypy
|
||||
|
||||
def process_request(
|
||||
connection: ServerConnection,
|
||||
request: Request,
|
||||
) -> Response | None:
|
||||
"""
|
||||
Perform HTTP Basic Authentication.
|
||||
|
||||
If it succeeds, set the connection's ``username`` attribute and return
|
||||
:obj:`None`. If it fails, return an HTTP 401 Unauthorized responss.
|
||||
|
||||
"""
|
||||
try:
|
||||
authorization = request.headers["Authorization"]
|
||||
except KeyError:
|
||||
response = connection.respond(
|
||||
http.HTTPStatus.UNAUTHORIZED,
|
||||
"Missing credentials\n",
|
||||
)
|
||||
response.headers["WWW-Authenticate"] = build_www_authenticate_basic(realm)
|
||||
return response
|
||||
|
||||
try:
|
||||
username, password = parse_authorization_basic(authorization)
|
||||
except InvalidHeader:
|
||||
response = connection.respond(
|
||||
http.HTTPStatus.UNAUTHORIZED,
|
||||
"Unsupported credentials\n",
|
||||
)
|
||||
response.headers["WWW-Authenticate"] = build_www_authenticate_basic(realm)
|
||||
return response
|
||||
|
||||
if not check_credentials(username, password):
|
||||
response = connection.respond(
|
||||
http.HTTPStatus.UNAUTHORIZED,
|
||||
"Invalid credentials\n",
|
||||
)
|
||||
response.headers["WWW-Authenticate"] = build_www_authenticate_basic(realm)
|
||||
return response
|
||||
|
||||
connection.username = username
|
||||
return None
|
||||
|
||||
return process_request
|
@ -1,45 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
|
||||
|
||||
__all__ = ["Deadline"]
|
||||
|
||||
|
||||
class Deadline:
|
||||
"""
|
||||
Manage timeouts across multiple steps.
|
||||
|
||||
Args:
|
||||
timeout: Time available in seconds or :obj:`None` if there is no limit.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, timeout: float | None) -> None:
|
||||
self.deadline: float | None
|
||||
if timeout is None:
|
||||
self.deadline = None
|
||||
else:
|
||||
self.deadline = time.monotonic() + timeout
|
||||
|
||||
def timeout(self, *, raise_if_elapsed: bool = True) -> float | None:
|
||||
"""
|
||||
Calculate a timeout from a deadline.
|
||||
|
||||
Args:
|
||||
raise_if_elapsed: Whether to raise :exc:`TimeoutError`
|
||||
if the deadline lapsed.
|
||||
|
||||
Raises:
|
||||
TimeoutError: If the deadline lapsed.
|
||||
|
||||
Returns:
|
||||
Time left in seconds or :obj:`None` if there is no limit.
|
||||
|
||||
"""
|
||||
if self.deadline is None:
|
||||
return None
|
||||
timeout = self.deadline - time.monotonic()
|
||||
if raise_if_elapsed and timeout <= 0:
|
||||
raise TimeoutError("timed out")
|
||||
return timeout
|
@ -1,77 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import http
|
||||
import logging
|
||||
import typing
|
||||
from typing import Any, List, NewType, Optional, Tuple, Union
|
||||
|
||||
|
||||
__all__ = [
|
||||
"Data",
|
||||
"LoggerLike",
|
||||
"StatusLike",
|
||||
"Origin",
|
||||
"Subprotocol",
|
||||
"ExtensionName",
|
||||
"ExtensionParameter",
|
||||
]
|
||||
|
||||
|
||||
# Public types used in the signature of public APIs
|
||||
|
||||
# Change to str | bytes when dropping Python < 3.10.
|
||||
Data = Union[str, bytes]
|
||||
"""Types supported in a WebSocket message:
|
||||
:class:`str` for a Text_ frame, :class:`bytes` for a Binary_.
|
||||
|
||||
.. _Text: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6
|
||||
.. _Binary : https://datatracker.ietf.org/doc/html/rfc6455#section-5.6
|
||||
|
||||
"""
|
||||
|
||||
|
||||
# Change to logging.Logger | ... when dropping Python < 3.10.
|
||||
if typing.TYPE_CHECKING:
|
||||
LoggerLike = Union[logging.Logger, logging.LoggerAdapter[Any]]
|
||||
"""Types accepted where a :class:`~logging.Logger` is expected."""
|
||||
else: # remove this branch when dropping support for Python < 3.11
|
||||
LoggerLike = Union[logging.Logger, logging.LoggerAdapter]
|
||||
"""Types accepted where a :class:`~logging.Logger` is expected."""
|
||||
|
||||
|
||||
# Change to http.HTTPStatus | int when dropping Python < 3.10.
|
||||
StatusLike = Union[http.HTTPStatus, int]
|
||||
"""
|
||||
Types accepted where an :class:`~http.HTTPStatus` is expected."""
|
||||
|
||||
|
||||
Origin = NewType("Origin", str)
|
||||
"""Value of a ``Origin`` header."""
|
||||
|
||||
|
||||
Subprotocol = NewType("Subprotocol", str)
|
||||
"""Subprotocol in a ``Sec-WebSocket-Protocol`` header."""
|
||||
|
||||
|
||||
ExtensionName = NewType("ExtensionName", str)
|
||||
"""Name of a WebSocket extension."""
|
||||
|
||||
# Change to tuple[str, Optional[str]] when dropping Python < 3.9.
|
||||
# Change to tuple[str, str | None] when dropping Python < 3.10.
|
||||
ExtensionParameter = Tuple[str, Optional[str]]
|
||||
"""Parameter of a WebSocket extension."""
|
||||
|
||||
|
||||
# Private types
|
||||
|
||||
# Change to tuple[.., list[...]] when dropping Python < 3.9.
|
||||
ExtensionHeader = Tuple[ExtensionName, List[ExtensionParameter]]
|
||||
"""Extension in a ``Sec-WebSocket-Extensions`` header."""
|
||||
|
||||
|
||||
ConnectionOption = NewType("ConnectionOption", str)
|
||||
"""Connection option in a ``Connection`` header."""
|
||||
|
||||
|
||||
UpgradeProtocol = NewType("UpgradeProtocol", str)
|
||||
"""Upgrade protocol in an ``Upgrade`` header."""
|
@ -1,107 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
import urllib.parse
|
||||
|
||||
from .exceptions import InvalidURI
|
||||
|
||||
|
||||
__all__ = ["parse_uri", "WebSocketURI"]
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class WebSocketURI:
|
||||
"""
|
||||
WebSocket URI.
|
||||
|
||||
Attributes:
|
||||
secure: :obj:`True` for a ``wss`` URI, :obj:`False` for a ``ws`` URI.
|
||||
host: Normalized to lower case.
|
||||
port: Always set even if it's the default.
|
||||
path: May be empty.
|
||||
query: May be empty if the URI doesn't include a query component.
|
||||
username: Available when the URI contains `User Information`_.
|
||||
password: Available when the URI contains `User Information`_.
|
||||
|
||||
.. _User Information: https://datatracker.ietf.org/doc/html/rfc3986#section-3.2.1
|
||||
|
||||
"""
|
||||
|
||||
secure: bool
|
||||
host: str
|
||||
port: int
|
||||
path: str
|
||||
query: str
|
||||
username: str | None = None
|
||||
password: str | None = None
|
||||
|
||||
@property
|
||||
def resource_name(self) -> str:
|
||||
if self.path:
|
||||
resource_name = self.path
|
||||
else:
|
||||
resource_name = "/"
|
||||
if self.query:
|
||||
resource_name += "?" + self.query
|
||||
return resource_name
|
||||
|
||||
@property
|
||||
def user_info(self) -> tuple[str, str] | None:
|
||||
if self.username is None:
|
||||
return None
|
||||
assert self.password is not None
|
||||
return (self.username, self.password)
|
||||
|
||||
|
||||
# All characters from the gen-delims and sub-delims sets in RFC 3987.
|
||||
DELIMS = ":/?#[]@!$&'()*+,;="
|
||||
|
||||
|
||||
def parse_uri(uri: str) -> WebSocketURI:
|
||||
"""
|
||||
Parse and validate a WebSocket URI.
|
||||
|
||||
Args:
|
||||
uri: WebSocket URI.
|
||||
|
||||
Returns:
|
||||
Parsed WebSocket URI.
|
||||
|
||||
Raises:
|
||||
InvalidURI: If ``uri`` isn't a valid WebSocket URI.
|
||||
|
||||
"""
|
||||
parsed = urllib.parse.urlparse(uri)
|
||||
if parsed.scheme not in ["ws", "wss"]:
|
||||
raise InvalidURI(uri, "scheme isn't ws or wss")
|
||||
if parsed.hostname is None:
|
||||
raise InvalidURI(uri, "hostname isn't provided")
|
||||
if parsed.fragment != "":
|
||||
raise InvalidURI(uri, "fragment identifier is meaningless")
|
||||
|
||||
secure = parsed.scheme == "wss"
|
||||
host = parsed.hostname
|
||||
port = parsed.port or (443 if secure else 80)
|
||||
path = parsed.path
|
||||
query = parsed.query
|
||||
username = parsed.username
|
||||
password = parsed.password
|
||||
# urllib.parse.urlparse accepts URLs with a username but without a
|
||||
# password. This doesn't make sense for HTTP Basic Auth credentials.
|
||||
if username is not None and password is None:
|
||||
raise InvalidURI(uri, "username provided without password")
|
||||
|
||||
try:
|
||||
uri.encode("ascii")
|
||||
except UnicodeEncodeError:
|
||||
# Input contains non-ASCII characters.
|
||||
# It must be an IRI. Convert it to a URI.
|
||||
host = host.encode("idna").decode()
|
||||
path = urllib.parse.quote(path, safe=DELIMS)
|
||||
query = urllib.parse.quote(query, safe=DELIMS)
|
||||
if username is not None:
|
||||
assert password is not None
|
||||
username = urllib.parse.quote(username, safe=DELIMS)
|
||||
password = urllib.parse.quote(password, safe=DELIMS)
|
||||
|
||||
return WebSocketURI(secure, host, port, path, query, username, password)
|
@ -1,51 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import hashlib
|
||||
import secrets
|
||||
import sys
|
||||
|
||||
|
||||
__all__ = ["accept_key", "apply_mask"]
|
||||
|
||||
|
||||
GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
|
||||
|
||||
|
||||
def generate_key() -> str:
|
||||
"""
|
||||
Generate a random key for the Sec-WebSocket-Key header.
|
||||
|
||||
"""
|
||||
key = secrets.token_bytes(16)
|
||||
return base64.b64encode(key).decode()
|
||||
|
||||
|
||||
def accept_key(key: str) -> str:
|
||||
"""
|
||||
Compute the value of the Sec-WebSocket-Accept header.
|
||||
|
||||
Args:
|
||||
key: Value of the Sec-WebSocket-Key header.
|
||||
|
||||
"""
|
||||
sha1 = hashlib.sha1((key + GUID).encode()).digest()
|
||||
return base64.b64encode(sha1).decode()
|
||||
|
||||
|
||||
def apply_mask(data: bytes, mask: bytes) -> bytes:
|
||||
"""
|
||||
Apply masking to the data of a WebSocket message.
|
||||
|
||||
Args:
|
||||
data: Data to mask.
|
||||
mask: 4-bytes mask.
|
||||
|
||||
"""
|
||||
if len(mask) != 4:
|
||||
raise ValueError("mask must contain 4 bytes")
|
||||
|
||||
data_int = int.from_bytes(data, sys.byteorder)
|
||||
mask_repeated = mask * (len(data) // 4) + mask[: len(data) % 4]
|
||||
mask_int = int.from_bytes(mask_repeated, sys.byteorder)
|
||||
return (data_int ^ mask_int).to_bytes(len(data), sys.byteorder)
|
@ -1,92 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib.metadata
|
||||
|
||||
|
||||
__all__ = ["tag", "version", "commit"]
|
||||
|
||||
|
||||
# ========= =========== ===================
|
||||
# release development
|
||||
# ========= =========== ===================
|
||||
# tag X.Y X.Y (upcoming)
|
||||
# version X.Y X.Y.dev1+g5678cde
|
||||
# commit X.Y 5678cde
|
||||
# ========= =========== ===================
|
||||
|
||||
|
||||
# When tagging a release, set `released = True`.
|
||||
# After tagging a release, set `released = False` and increment `tag`.
|
||||
|
||||
released = False
|
||||
|
||||
tag = version = commit = "13.1"
|
||||
|
||||
|
||||
if not released: # pragma: no cover
|
||||
import pathlib
|
||||
import re
|
||||
import subprocess
|
||||
|
||||
def get_version(tag: str) -> str:
|
||||
# Since setup.py executes the contents of src/websockets/version.py,
|
||||
# __file__ can point to either of these two files.
|
||||
file_path = pathlib.Path(__file__)
|
||||
root_dir = file_path.parents[0 if file_path.name == "setup.py" else 2]
|
||||
|
||||
# Read version from package metadata if it is installed.
|
||||
try:
|
||||
version = importlib.metadata.version("websockets")
|
||||
except ImportError:
|
||||
pass
|
||||
else:
|
||||
# Check that this file belongs to the installed package.
|
||||
files = importlib.metadata.files("websockets")
|
||||
if files:
|
||||
version_files = [f for f in files if f.name == file_path.name]
|
||||
if version_files:
|
||||
version_file = version_files[0]
|
||||
if version_file.locate() == file_path:
|
||||
return version
|
||||
|
||||
# Read version from git if available.
|
||||
try:
|
||||
description = subprocess.run(
|
||||
["git", "describe", "--dirty", "--tags", "--long"],
|
||||
capture_output=True,
|
||||
cwd=root_dir,
|
||||
timeout=1,
|
||||
check=True,
|
||||
text=True,
|
||||
).stdout.strip()
|
||||
# subprocess.run raises FileNotFoundError if git isn't on $PATH.
|
||||
except (
|
||||
FileNotFoundError,
|
||||
subprocess.CalledProcessError,
|
||||
subprocess.TimeoutExpired,
|
||||
):
|
||||
pass
|
||||
else:
|
||||
description_re = r"[0-9.]+-([0-9]+)-(g[0-9a-f]{7,}(?:-dirty)?)"
|
||||
match = re.fullmatch(description_re, description)
|
||||
if match is None:
|
||||
raise ValueError(f"Unexpected git description: {description}")
|
||||
distance, remainder = match.groups()
|
||||
remainder = remainder.replace("-", ".") # required by PEP 440
|
||||
return f"{tag}.dev{distance}+{remainder}"
|
||||
|
||||
# Avoid crashing if the development version cannot be determined.
|
||||
return f"{tag}.dev0+gunknown"
|
||||
|
||||
version = get_version(tag)
|
||||
|
||||
def get_commit(tag: str, version: str) -> str:
|
||||
# Extract commit from version, falling back to tag if not available.
|
||||
version_re = r"[0-9.]+\.dev[0-9]+\+g([0-9a-f]{7,}|unknown)(?:\.dirty)?"
|
||||
match = re.fullmatch(version_re, version)
|
||||
if match is None:
|
||||
raise ValueError(f"Unexpected version: {version}")
|
||||
(commit,) = match.groups()
|
||||
return tag if commit == "unknown" else commit
|
||||
|
||||
commit = get_commit(tag, version)
|
Loading…
Reference in New Issue
Block a user