Upgrade yt_dlp and download script

This commit is contained in:
2025-05-02 16:11:08 -05:00
parent 3a2e8eeb08
commit d68d9ce4f9
1194 changed files with 60099 additions and 44436 deletions

View File

@@ -1,6 +1,9 @@
# flake8: noqa: 401
# flake8: noqa: F401
import warnings
from .common import (
HEADRequest,
PATCHRequest,
PUTRequest,
Request,
RequestDirector,
@@ -11,3 +14,25 @@ from .common import (
# isort: split
# TODO: all request handlers should be safely imported
from . import _urllib
from ..utils import bug_reports_message
try:
from . import _requests
except ImportError:
pass
except Exception as e:
warnings.warn(f'Failed to import "requests" request handler: {e}' + bug_reports_message())
try:
from . import _websockets
except ImportError:
pass
except Exception as e:
warnings.warn(f'Failed to import "websockets" request handler: {e}' + bug_reports_message())
try:
from . import _curlcffi
except ImportError:
pass
except Exception as e:
warnings.warn(f'Failed to import "curl_cffi" request handler: {e}' + bug_reports_message())

View File

@@ -0,0 +1,296 @@
from __future__ import annotations
import io
import itertools
import math
import re
import urllib.parse
from ._helper import InstanceStoreMixin, select_proxy
from .common import (
Features,
Request,
Response,
register_preference,
register_rh,
)
from .exceptions import (
CertificateVerifyError,
HTTPError,
IncompleteRead,
ProxyError,
SSLError,
TransportError,
)
from .impersonate import ImpersonateRequestHandler, ImpersonateTarget
from ..dependencies import curl_cffi, certifi
from ..utils import int_or_none
if curl_cffi is None:
raise ImportError('curl_cffi is not installed')
curl_cffi_version = tuple(map(int, re.split(r'[^\d]+', curl_cffi.__version__)[:3]))
if curl_cffi_version != (0, 5, 10) and not (0, 10) <= curl_cffi_version:
curl_cffi._yt_dlp__version = f'{curl_cffi.__version__} (unsupported)'
raise ImportError('Only curl_cffi versions 0.5.10 and 0.10.x are supported')
import curl_cffi.requests
from curl_cffi.const import CurlECode, CurlOpt
class CurlCFFIResponseReader(io.IOBase):
def __init__(self, response: curl_cffi.requests.Response):
self._response = response
self._iterator = response.iter_content()
self._buffer = b''
self.bytes_read = 0
def readable(self):
return True
def read(self, size=None):
exception_raised = True
try:
while self._iterator and (size is None or len(self._buffer) < size):
chunk = next(self._iterator, None)
if chunk is None:
self._iterator = None
break
self._buffer += chunk
self.bytes_read += len(chunk)
if size is None:
size = len(self._buffer)
data = self._buffer[:size]
self._buffer = self._buffer[size:]
# "free" the curl instance if the response is fully read.
# curl_cffi doesn't do this automatically and only allows one open response per thread
if not self._iterator and not self._buffer:
self.close()
exception_raised = False
return data
finally:
if exception_raised:
self.close()
def close(self):
if not self.closed:
self._response.close()
self._buffer = b''
super().close()
class CurlCFFIResponseAdapter(Response):
fp: CurlCFFIResponseReader
def __init__(self, response: curl_cffi.requests.Response):
super().__init__(
fp=CurlCFFIResponseReader(response),
headers=response.headers,
url=response.url,
status=response.status_code)
def read(self, amt=None):
try:
return self.fp.read(amt)
except curl_cffi.requests.errors.RequestsError as e:
if e.code == CurlECode.PARTIAL_FILE:
content_length = e.response and int_or_none(e.response.headers.get('Content-Length'))
raise IncompleteRead(
partial=self.fp.bytes_read,
expected=content_length - self.fp.bytes_read if content_length is not None else None,
cause=e) from e
raise TransportError(cause=e) from e
# See: https://github.com/lexiforest/curl_cffi?tab=readme-ov-file#supported-impersonate-browsers
# https://github.com/lexiforest/curl-impersonate?tab=readme-ov-file#supported-browsers
BROWSER_TARGETS: dict[tuple[int, ...], dict[str, ImpersonateTarget]] = {
(0, 5): {
'chrome99': ImpersonateTarget('chrome', '99', 'windows', '10'),
'chrome99_android': ImpersonateTarget('chrome', '99', 'android', '12'),
'chrome100': ImpersonateTarget('chrome', '100', 'windows', '10'),
'chrome101': ImpersonateTarget('chrome', '101', 'windows', '10'),
'chrome104': ImpersonateTarget('chrome', '104', 'windows', '10'),
'chrome107': ImpersonateTarget('chrome', '107', 'windows', '10'),
'chrome110': ImpersonateTarget('chrome', '110', 'windows', '10'),
'edge99': ImpersonateTarget('edge', '99', 'windows', '10'),
'edge101': ImpersonateTarget('edge', '101', 'windows', '10'),
'safari15_3': ImpersonateTarget('safari', '15.3', 'macos', '11'),
'safari15_5': ImpersonateTarget('safari', '15.5', 'macos', '12'),
},
(0, 7): {
'chrome116': ImpersonateTarget('chrome', '116', 'windows', '10'),
'chrome119': ImpersonateTarget('chrome', '119', 'macos', '14'),
'chrome120': ImpersonateTarget('chrome', '120', 'macos', '14'),
'chrome123': ImpersonateTarget('chrome', '123', 'macos', '14'),
'chrome124': ImpersonateTarget('chrome', '124', 'macos', '14'),
'safari17_0': ImpersonateTarget('safari', '17.0', 'macos', '14'),
'safari17_2_ios': ImpersonateTarget('safari', '17.2', 'ios', '17.2'),
},
(0, 9): {
'safari15_3': ImpersonateTarget('safari', '15.3', 'macos', '14'),
'safari15_5': ImpersonateTarget('safari', '15.5', 'macos', '14'),
'chrome119': ImpersonateTarget('chrome', '119', 'macos', '14'),
'chrome120': ImpersonateTarget('chrome', '120', 'macos', '14'),
'chrome123': ImpersonateTarget('chrome', '123', 'macos', '14'),
'chrome124': ImpersonateTarget('chrome', '124', 'macos', '14'),
'chrome131': ImpersonateTarget('chrome', '131', 'macos', '14'),
'chrome131_android': ImpersonateTarget('chrome', '131', 'android', '14'),
'chrome133a': ImpersonateTarget('chrome', '133', 'macos', '15'),
'firefox133': ImpersonateTarget('firefox', '133', 'macos', '14'),
'safari18_0': ImpersonateTarget('safari', '18.0', 'macos', '15'),
'safari18_0_ios': ImpersonateTarget('safari', '18.0', 'ios', '18.0'),
},
(0, 10): {
'firefox135': ImpersonateTarget('firefox', '135', 'macos', '14'),
},
}
@register_rh
class CurlCFFIRH(ImpersonateRequestHandler, InstanceStoreMixin):
RH_NAME = 'curl_cffi'
_SUPPORTED_URL_SCHEMES = ('http', 'https')
_SUPPORTED_FEATURES = (Features.NO_PROXY, Features.ALL_PROXY)
_SUPPORTED_PROXY_SCHEMES = ('http', 'https', 'socks4', 'socks4a', 'socks5', 'socks5h')
_SUPPORTED_IMPERSONATE_TARGET_MAP = {
target: name if curl_cffi_version >= (0, 9) else curl_cffi.requests.BrowserType[name]
for name, target in dict(sorted(itertools.chain.from_iterable(
targets.items()
for version, targets in BROWSER_TARGETS.items()
if curl_cffi_version >= version
), key=lambda x: (
# deprioritize mobile targets since they give very different behavior
x[1].os not in ('ios', 'android'),
# prioritize edge < firefox < safari < chrome
('edge', 'firefox', 'safari', 'chrome').index(x[1].client),
# prioritize newest version
float(x[1].version) if x[1].version else 0,
# group by os name
x[1].os,
), reverse=True)).items()
}
def _create_instance(self, cookiejar=None):
return curl_cffi.requests.Session(cookies=cookiejar)
def _check_extensions(self, extensions):
super()._check_extensions(extensions)
extensions.pop('impersonate', None)
extensions.pop('cookiejar', None)
extensions.pop('timeout', None)
# CurlCFFIRH ignores legacy ssl options currently.
# Impersonation generally uses a looser SSL configuration than urllib/requests.
extensions.pop('legacy_ssl', None)
def send(self, request: Request) -> Response:
target = self._get_request_target(request)
try:
response = super().send(request)
except HTTPError as e:
e.response.extensions['impersonate'] = target
raise
response.extensions['impersonate'] = target
return response
def _send(self, request: Request):
max_redirects_exceeded = False
session: curl_cffi.requests.Session = self._get_instance(
cookiejar=self._get_cookiejar(request) if 'cookie' not in request.headers else None)
if self.verbose:
session.curl.setopt(CurlOpt.VERBOSE, 1)
proxies = self._get_proxies(request)
if 'no' in proxies:
session.curl.setopt(CurlOpt.NOPROXY, proxies['no'])
proxies.pop('no', None)
# curl doesn't support per protocol proxies, so we select the one that matches the request protocol
proxy = select_proxy(request.url, proxies=proxies)
if proxy:
session.curl.setopt(CurlOpt.PROXY, proxy)
scheme = urllib.parse.urlparse(request.url).scheme.lower()
if scheme != 'http':
# Enable HTTP CONNECT for HTTPS urls.
# Don't use CONNECT for http for compatibility with urllib behaviour.
# See: https://curl.se/libcurl/c/CURLOPT_HTTPPROXYTUNNEL.html
session.curl.setopt(CurlOpt.HTTPPROXYTUNNEL, 1)
# curl_cffi does not currently set these for proxies
session.curl.setopt(CurlOpt.PROXY_CAINFO, certifi.where())
if not self.verify:
session.curl.setopt(CurlOpt.PROXY_SSL_VERIFYPEER, 0)
session.curl.setopt(CurlOpt.PROXY_SSL_VERIFYHOST, 0)
headers = self._get_impersonate_headers(request)
if self._client_cert:
session.curl.setopt(CurlOpt.SSLCERT, self._client_cert['client_certificate'])
client_certificate_key = self._client_cert.get('client_certificate_key')
client_certificate_password = self._client_cert.get('client_certificate_password')
if client_certificate_key:
session.curl.setopt(CurlOpt.SSLKEY, client_certificate_key)
if client_certificate_password:
session.curl.setopt(CurlOpt.KEYPASSWD, client_certificate_password)
timeout = self._calculate_timeout(request)
# set CURLOPT_LOW_SPEED_LIMIT and CURLOPT_LOW_SPEED_TIME to act as a read timeout. [1]
# This is required only for 0.5.10 [2]
# Note: CURLOPT_LOW_SPEED_TIME is in seconds, so we need to round up to the nearest second. [3]
# [1] https://unix.stackexchange.com/a/305311
# [2] https://github.com/yifeikong/curl_cffi/issues/156
# [3] https://curl.se/libcurl/c/CURLOPT_LOW_SPEED_TIME.html
session.curl.setopt(CurlOpt.LOW_SPEED_LIMIT, 1) # 1 byte per second
session.curl.setopt(CurlOpt.LOW_SPEED_TIME, math.ceil(timeout))
try:
curl_response = session.request(
method=request.method,
url=request.url,
headers=headers,
data=request.data,
verify=self.verify,
max_redirects=5,
timeout=(timeout, timeout),
impersonate=self._SUPPORTED_IMPERSONATE_TARGET_MAP.get(
self._get_request_target(request)),
interface=self.source_address,
stream=True,
)
except curl_cffi.requests.errors.RequestsError as e:
if e.code == CurlECode.PEER_FAILED_VERIFICATION:
raise CertificateVerifyError(cause=e) from e
elif e.code == CurlECode.SSL_CONNECT_ERROR:
raise SSLError(cause=e) from e
elif e.code == CurlECode.TOO_MANY_REDIRECTS:
max_redirects_exceeded = True
curl_response = e.response
elif (
e.code == CurlECode.PROXY
or (e.code == CurlECode.RECV_ERROR and 'CONNECT' in str(e))
):
raise ProxyError(cause=e) from e
else:
raise TransportError(cause=e) from e
response = CurlCFFIResponseAdapter(curl_response)
if not 200 <= response.status < 300:
raise HTTPError(response, redirect_loop=max_redirects_exceeded)
return response
@register_preference(CurlCFFIRH)
def curl_cffi_preference(rh, request):
return -100

View File

@@ -2,15 +2,17 @@ from __future__ import annotations
import contextlib
import functools
import os
import socket
import ssl
import sys
import typing
import urllib.parse
import urllib.request
from .exceptions import RequestError, UnsupportedRequest
from .exceptions import RequestError
from ..dependencies import certifi
from ..socks import ProxyType
from ..socks import ProxyType, sockssocket
from ..utils import format_field, traverse_obj
if typing.TYPE_CHECKING:
@@ -120,6 +122,9 @@ def make_ssl_context(
context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
context.check_hostname = verify
context.verify_mode = ssl.CERT_REQUIRED if verify else ssl.CERT_NONE
# OpenSSL 1.1.1+ Python 3.8+ keylog file
if hasattr(context, 'keylog_filename'):
context.keylog_filename = os.environ.get('SSLKEYLOGFILE') or None
# Some servers may reject requests if ALPN extension is not sent. See:
# https://github.com/python/cpython/issues/85140
@@ -201,8 +206,82 @@ def wrap_request_errors(func):
def wrapper(self, *args, **kwargs):
try:
return func(self, *args, **kwargs)
except UnsupportedRequest as e:
except RequestError as e:
if e.handler is None:
e.handler = self
raise
return wrapper
def _socket_connect(ip_addr, timeout, source_address):
af, socktype, proto, canonname, sa = ip_addr
sock = socket.socket(af, socktype, proto)
try:
if timeout is not socket._GLOBAL_DEFAULT_TIMEOUT:
sock.settimeout(timeout)
if source_address:
sock.bind(source_address)
sock.connect(sa)
return sock
except OSError:
sock.close()
raise
def create_socks_proxy_socket(dest_addr, proxy_args, proxy_ip_addr, timeout, source_address):
af, socktype, proto, canonname, sa = proxy_ip_addr
sock = sockssocket(af, socktype, proto)
try:
connect_proxy_args = proxy_args.copy()
connect_proxy_args.update({'addr': sa[0], 'port': sa[1]})
sock.setproxy(**connect_proxy_args)
if timeout is not socket._GLOBAL_DEFAULT_TIMEOUT:
sock.settimeout(timeout)
if source_address:
sock.bind(source_address)
sock.connect(dest_addr)
return sock
except OSError:
sock.close()
raise
def create_connection(
address,
timeout=socket._GLOBAL_DEFAULT_TIMEOUT,
source_address=None,
*,
_create_socket_func=_socket_connect,
):
# Work around socket.create_connection() which tries all addresses from getaddrinfo() including IPv6.
# This filters the addresses based on the given source_address.
# Based on: https://github.com/python/cpython/blob/main/Lib/socket.py#L810
host, port = address
ip_addrs = socket.getaddrinfo(host, port, 0, socket.SOCK_STREAM)
if not ip_addrs:
raise OSError('getaddrinfo returns an empty list')
if source_address is not None:
af = socket.AF_INET if ':' not in source_address[0] else socket.AF_INET6
ip_addrs = [addr for addr in ip_addrs if addr[0] == af]
if not ip_addrs:
raise OSError(
f'No remote IPv{4 if af == socket.AF_INET else 6} addresses available for connect. '
f'Can\'t use "{source_address[0]}" as source address')
err = None
for ip_addr in ip_addrs:
try:
sock = _create_socket_func(ip_addr, timeout, source_address)
# Explicitly break __traceback__ reference cycle
# https://bugs.python.org/issue36820
err = None
return sock
except OSError as e:
err = e
try:
raise err
finally:
# Explicitly break __traceback__ reference cycle
# https://bugs.python.org/issue36820
err = None

View File

@@ -0,0 +1,426 @@
from __future__ import annotations
import contextlib
import functools
import http.client
import logging
import re
import socket
import warnings
from ..dependencies import brotli, requests, urllib3
from ..utils import bug_reports_message, int_or_none, variadic
from ..utils.networking import normalize_url
if requests is None:
raise ImportError('requests module is not installed')
if urllib3 is None:
raise ImportError('urllib3 module is not installed')
urllib3_version = tuple(int_or_none(x, default=0) for x in urllib3.__version__.split('.'))
if urllib3_version < (1, 26, 17):
urllib3._yt_dlp__version = f'{urllib3.__version__} (unsupported)'
raise ImportError('Only urllib3 >= 1.26.17 is supported')
if requests.__build__ < 0x023202:
requests._yt_dlp__version = f'{requests.__version__} (unsupported)'
raise ImportError('Only requests >= 2.32.2 is supported')
import requests.adapters
import requests.utils
import urllib3.connection
import urllib3.exceptions
import urllib3.util
from ._helper import (
InstanceStoreMixin,
add_accept_encoding_header,
create_connection,
create_socks_proxy_socket,
get_redirect_method,
make_socks_proxy_opts,
select_proxy,
)
from .common import (
Features,
RequestHandler,
Response,
register_preference,
register_rh,
)
from .exceptions import (
CertificateVerifyError,
HTTPError,
IncompleteRead,
ProxyError,
RequestError,
SSLError,
TransportError,
)
from ..socks import ProxyError as SocksProxyError
SUPPORTED_ENCODINGS = [
'gzip', 'deflate',
]
if brotli is not None:
SUPPORTED_ENCODINGS.append('br')
'''
Override urllib3's behavior to not convert lower-case percent-encoded characters
to upper-case during url normalization process.
RFC3986 defines that the lower or upper case percent-encoded hexidecimal characters are equivalent
and normalizers should convert them to uppercase for consistency [1].
However, some sites may have an incorrect implementation where they provide
a percent-encoded url that is then compared case-sensitively.[2]
While this is a very rare case, since urllib does not do this normalization step, it
is best to avoid it in requests too for compatability reasons.
1: https://tools.ietf.org/html/rfc3986#section-2.1
2: https://github.com/streamlink/streamlink/pull/4003
'''
class Urllib3PercentREOverride:
def __init__(self, r: re.Pattern):
self.re = r
# pass through all other attribute calls to the original re
def __getattr__(self, item):
return self.re.__getattribute__(item)
def subn(self, repl, string, *args, **kwargs):
return string, self.re.subn(repl, string, *args, **kwargs)[1]
# urllib3 >= 1.25.8 uses subn:
# https://github.com/urllib3/urllib3/commit/a2697e7c6b275f05879b60f593c5854a816489f0
import urllib3.util.url
if hasattr(urllib3.util.url, 'PERCENT_RE'):
urllib3.util.url.PERCENT_RE = Urllib3PercentREOverride(urllib3.util.url.PERCENT_RE)
elif hasattr(urllib3.util.url, '_PERCENT_RE'): # urllib3 >= 2.0.0
urllib3.util.url._PERCENT_RE = Urllib3PercentREOverride(urllib3.util.url._PERCENT_RE)
else:
warnings.warn('Failed to patch PERCENT_RE in urllib3 (does the attribute exist?)' + bug_reports_message())
'''
Workaround for issue in urllib.util.ssl_.py: ssl_wrap_context does not pass
server_hostname to SSLContext.wrap_socket if server_hostname is an IP,
however this is an issue because we set check_hostname to True in our SSLContext.
Monkey-patching IS_SECURETRANSPORT forces ssl_wrap_context to pass server_hostname regardless.
This has been fixed in urllib3 2.0+.
See: https://github.com/urllib3/urllib3/issues/517
'''
if urllib3_version < (2, 0, 0):
with contextlib.suppress(Exception):
urllib3.util.IS_SECURETRANSPORT = urllib3.util.ssl_.IS_SECURETRANSPORT = True
# Requests will not automatically handle no_proxy by default
# due to buggy no_proxy handling with proxy dict [1].
# 1. https://github.com/psf/requests/issues/5000
requests.adapters.select_proxy = select_proxy
class RequestsResponseAdapter(Response):
def __init__(self, res: requests.models.Response):
super().__init__(
fp=res.raw, headers=res.headers, url=res.url,
status=res.status_code, reason=res.reason)
self._requests_response = res
def read(self, amt: int | None = None):
try:
# Interact with urllib3 response directly.
return self.fp.read(amt, decode_content=True)
# See urllib3.response.HTTPResponse.read() for exceptions raised on read
except urllib3.exceptions.SSLError as e:
raise SSLError(cause=e) from e
except urllib3.exceptions.ProtocolError as e:
# IncompleteRead is always contained within ProtocolError
# See urllib3.response.HTTPResponse._error_catcher()
ir_err = next(
(err for err in (e.__context__, e.__cause__, *variadic(e.args))
if isinstance(err, http.client.IncompleteRead)), None)
if ir_err is not None:
# `urllib3.exceptions.IncompleteRead` is subclass of `http.client.IncompleteRead`
# but uses an `int` for its `partial` property.
partial = ir_err.partial if isinstance(ir_err.partial, int) else len(ir_err.partial)
raise IncompleteRead(partial=partial, expected=ir_err.expected) from e
raise TransportError(cause=e) from e
except urllib3.exceptions.HTTPError as e:
# catch-all for any other urllib3 response exceptions
raise TransportError(cause=e) from e
class RequestsHTTPAdapter(requests.adapters.HTTPAdapter):
def __init__(self, ssl_context=None, proxy_ssl_context=None, source_address=None, **kwargs):
self._pm_args = {}
if ssl_context:
self._pm_args['ssl_context'] = ssl_context
if source_address:
self._pm_args['source_address'] = (source_address, 0)
self._proxy_ssl_context = proxy_ssl_context or ssl_context
super().__init__(**kwargs)
def init_poolmanager(self, *args, **kwargs):
return super().init_poolmanager(*args, **kwargs, **self._pm_args)
def proxy_manager_for(self, proxy, **proxy_kwargs):
extra_kwargs = {}
if not proxy.lower().startswith('socks') and self._proxy_ssl_context:
extra_kwargs['proxy_ssl_context'] = self._proxy_ssl_context
return super().proxy_manager_for(proxy, **proxy_kwargs, **self._pm_args, **extra_kwargs)
# Skip `requests` internal verification; we use our own SSLContext
def cert_verify(*args, **kwargs):
pass
# requests 2.32.2+: Reimplementation without `_urllib3_request_context`
def get_connection_with_tls_context(self, request, verify, proxies=None, cert=None):
url = urllib3.util.parse_url(request.url).url
manager = self.poolmanager
if proxy := select_proxy(url, proxies):
manager = self.proxy_manager_for(proxy)
return manager.connection_from_url(url)
class RequestsSession(requests.sessions.Session):
"""
Ensure unified redirect method handling with our urllib redirect handler.
"""
def rebuild_method(self, prepared_request, response):
new_method = get_redirect_method(prepared_request.method, response.status_code)
# HACK: requests removes headers/body on redirect unless code was a 307/308.
if new_method == prepared_request.method:
response._real_status_code = response.status_code
response.status_code = 308
prepared_request.method = new_method
# Requests fails to resolve dot segments on absolute redirect locations
# See: https://github.com/yt-dlp/yt-dlp/issues/9020
prepared_request.url = normalize_url(prepared_request.url)
def rebuild_auth(self, prepared_request, response):
# HACK: undo status code change from rebuild_method, if applicable.
# rebuild_auth runs after requests would remove headers/body based on status code
if hasattr(response, '_real_status_code'):
response.status_code = response._real_status_code
del response._real_status_code
return super().rebuild_auth(prepared_request, response)
class Urllib3LoggingFilter(logging.Filter):
def filter(self, record):
# Ignore HTTP request messages since HTTPConnection prints those
return record.msg != '%s://%s:%s "%s %s %s" %s %s'
class Urllib3LoggingHandler(logging.Handler):
"""Redirect urllib3 logs to our logger"""
def __init__(self, logger, *args, **kwargs):
super().__init__(*args, **kwargs)
self._logger = logger
def emit(self, record):
try:
msg = self.format(record)
if record.levelno >= logging.ERROR:
self._logger.error(msg)
else:
self._logger.stdout(msg)
except Exception:
self.handleError(record)
@register_rh
class RequestsRH(RequestHandler, InstanceStoreMixin):
"""Requests RequestHandler
https://github.com/psf/requests
"""
_SUPPORTED_URL_SCHEMES = ('http', 'https')
_SUPPORTED_ENCODINGS = tuple(SUPPORTED_ENCODINGS)
_SUPPORTED_PROXY_SCHEMES = ('http', 'https', 'socks4', 'socks4a', 'socks5', 'socks5h')
_SUPPORTED_FEATURES = (Features.NO_PROXY, Features.ALL_PROXY)
RH_NAME = 'requests'
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# Forward urllib3 debug messages to our logger
logger = logging.getLogger('urllib3')
self.__logging_handler = Urllib3LoggingHandler(logger=self._logger)
self.__logging_handler.setFormatter(logging.Formatter('requests: %(message)s'))
self.__logging_handler.addFilter(Urllib3LoggingFilter())
logger.addHandler(self.__logging_handler)
# TODO: Use a logger filter to suppress pool reuse warning instead
logger.setLevel(logging.ERROR)
if self.verbose:
# Setting this globally is not ideal, but is easier than hacking with urllib3.
# It could technically be problematic for scripts embedding yt-dlp.
# However, it is unlikely debug traffic is used in that context in a way this will cause problems.
urllib3.connection.HTTPConnection.debuglevel = 1
logger.setLevel(logging.DEBUG)
# this is expected if we are using --no-check-certificate
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
def close(self):
self._clear_instances()
# Remove the logging handler that contains a reference to our logger
# See: https://github.com/yt-dlp/yt-dlp/issues/8922
logging.getLogger('urllib3').removeHandler(self.__logging_handler)
def _check_extensions(self, extensions):
super()._check_extensions(extensions)
extensions.pop('cookiejar', None)
extensions.pop('timeout', None)
extensions.pop('legacy_ssl', None)
extensions.pop('keep_header_casing', None)
def _create_instance(self, cookiejar, legacy_ssl_support=None):
session = RequestsSession()
http_adapter = RequestsHTTPAdapter(
ssl_context=self._make_sslcontext(legacy_ssl_support=legacy_ssl_support),
source_address=self.source_address,
max_retries=urllib3.util.retry.Retry(False),
)
session.adapters.clear()
session.headers = requests.models.CaseInsensitiveDict({'Connection': 'keep-alive'})
session.mount('https://', http_adapter)
session.mount('http://', http_adapter)
session.cookies = cookiejar
session.trust_env = False # no need, we already load proxies from env
return session
def _prepare_headers(self, _, headers):
add_accept_encoding_header(headers, SUPPORTED_ENCODINGS)
def _send(self, request):
headers = self._get_headers(request)
max_redirects_exceeded = False
session = self._get_instance(
cookiejar=self._get_cookiejar(request),
legacy_ssl_support=request.extensions.get('legacy_ssl'),
)
try:
requests_res = session.request(
method=request.method,
url=request.url,
data=request.data,
headers=headers,
timeout=self._calculate_timeout(request),
proxies=self._get_proxies(request),
allow_redirects=True,
stream=True,
)
except requests.exceptions.TooManyRedirects as e:
max_redirects_exceeded = True
requests_res = e.response
except requests.exceptions.SSLError as e:
if 'CERTIFICATE_VERIFY_FAILED' in str(e):
raise CertificateVerifyError(cause=e) from e
raise SSLError(cause=e) from e
except requests.exceptions.ProxyError as e:
raise ProxyError(cause=e) from e
except (requests.exceptions.ConnectionError, requests.exceptions.Timeout) as e:
raise TransportError(cause=e) from e
except urllib3.exceptions.HTTPError as e:
# Catch any urllib3 exceptions that may leak through
raise TransportError(cause=e) from e
except requests.exceptions.RequestException as e:
# Miscellaneous Requests exceptions. May not necessary be network related e.g. InvalidURL
raise RequestError(cause=e) from e
res = RequestsResponseAdapter(requests_res)
if not 200 <= res.status < 300:
raise HTTPError(res, redirect_loop=max_redirects_exceeded)
return res
@register_preference(RequestsRH)
def requests_preference(rh, request):
return 100
# Use our socks proxy implementation with requests to avoid an extra dependency.
class SocksHTTPConnection(urllib3.connection.HTTPConnection):
def __init__(self, _socks_options, *args, **kwargs): # must use _socks_options to pass PoolKey checks
self._proxy_args = _socks_options
super().__init__(*args, **kwargs)
def _new_conn(self):
try:
return create_connection(
address=(self._proxy_args['addr'], self._proxy_args['port']),
timeout=self.timeout,
source_address=self.source_address,
_create_socket_func=functools.partial(
create_socks_proxy_socket, (self.host, self.port), self._proxy_args))
except (socket.timeout, TimeoutError) as e:
raise urllib3.exceptions.ConnectTimeoutError(
self, f'Connection to {self.host} timed out. (connect timeout={self.timeout})') from e
except SocksProxyError as e:
raise urllib3.exceptions.ProxyError(str(e), e) from e
except OSError as e:
raise urllib3.exceptions.NewConnectionError(
self, f'Failed to establish a new connection: {e}') from e
class SocksHTTPSConnection(SocksHTTPConnection, urllib3.connection.HTTPSConnection):
pass
class SocksHTTPConnectionPool(urllib3.HTTPConnectionPool):
ConnectionCls = SocksHTTPConnection
class SocksHTTPSConnectionPool(urllib3.HTTPSConnectionPool):
ConnectionCls = SocksHTTPSConnection
class SocksProxyManager(urllib3.PoolManager):
def __init__(self, socks_proxy, username=None, password=None, num_pools=10, headers=None, **connection_pool_kw):
connection_pool_kw['_socks_options'] = make_socks_proxy_opts(socks_proxy)
super().__init__(num_pools, headers, **connection_pool_kw)
self.pool_classes_by_scheme = {
'http': SocksHTTPConnectionPool,
'https': SocksHTTPSConnectionPool,
}
requests.adapters.SOCKSProxyManager = SocksProxyManager

View File

@@ -1,10 +1,8 @@
from __future__ import annotations
import functools
import gzip
import http.client
import io
import socket
import ssl
import urllib.error
import urllib.parse
@@ -24,6 +22,8 @@ from urllib.request import (
from ._helper import (
InstanceStoreMixin,
add_accept_encoding_header,
create_connection,
create_socks_proxy_socket,
get_redirect_method,
make_socks_proxy_opts,
select_proxy,
@@ -40,7 +40,6 @@ from .exceptions import (
)
from ..dependencies import brotli
from ..socks import ProxyError as SocksProxyError
from ..socks import sockssocket
from ..utils import update_url_query
from ..utils.networking import normalize_url
@@ -55,44 +54,10 @@ if brotli:
def _create_http_connection(http_class, source_address, *args, **kwargs):
hc = http_class(*args, **kwargs)
if hasattr(hc, '_create_connection'):
hc._create_connection = create_connection
if source_address is not None:
# This is to workaround _create_connection() from socket where it will try all
# address data from getaddrinfo() including IPv6. This filters the result from
# getaddrinfo() based on the source_address value.
# This is based on the cpython socket.create_connection() function.
# https://github.com/python/cpython/blob/master/Lib/socket.py#L691
def _create_connection(address, timeout=socket._GLOBAL_DEFAULT_TIMEOUT, source_address=None):
host, port = address
err = None
addrs = socket.getaddrinfo(host, port, 0, socket.SOCK_STREAM)
af = socket.AF_INET if '.' in source_address[0] else socket.AF_INET6
ip_addrs = [addr for addr in addrs if addr[0] == af]
if addrs and not ip_addrs:
ip_version = 'v4' if af == socket.AF_INET else 'v6'
raise OSError(
"No remote IP%s addresses available for connect, can't use '%s' as source address"
% (ip_version, source_address[0]))
for res in ip_addrs:
af, socktype, proto, canonname, sa = res
sock = None
try:
sock = socket.socket(af, socktype, proto)
if timeout is not socket._GLOBAL_DEFAULT_TIMEOUT:
sock.settimeout(timeout)
sock.bind(source_address)
sock.connect(sa)
err = None # Explicitly break reference cycle
return sock
except OSError as _:
err = _
if sock is not None:
sock.close()
if err is not None:
raise err
else:
raise OSError('getaddrinfo returns an empty list')
if hasattr(hc, '_create_connection'):
hc._create_connection = _create_connection
hc.source_address = (source_address, 0)
return hc
@@ -155,20 +120,11 @@ class HTTPHandler(urllib.request.AbstractHTTPHandler):
@staticmethod
def gz(data):
gz = gzip.GzipFile(fileobj=io.BytesIO(data), mode='rb')
try:
return gz.read()
except OSError as original_oserror:
# There may be junk add the end of the file
# See http://stackoverflow.com/q/4928560/35070 for details
for i in range(1, 1024):
try:
gz = gzip.GzipFile(fileobj=io.BytesIO(data[:-i]), mode='rb')
return gz.read()
except OSError:
continue
else:
raise original_oserror
# There may be junk added the end of the file
# We ignore it by only ever decoding a single gzip payload
if not data:
return data
return zlib.decompress(data, wbits=zlib.MAX_WBITS | 16)
def http_request(self, req):
# According to RFC 3986, URLs can not contain non-ASCII characters, however this is not
@@ -211,7 +167,7 @@ class HTTPHandler(urllib.request.AbstractHTTPHandler):
if 300 <= resp.code < 400:
location = resp.headers.get('Location')
if location:
# As of RFC 2616 default charset is iso-8859-1 that is respected by python 3
# As of RFC 2616 default charset is iso-8859-1 that is respected by Python 3
location = location.encode('iso-8859-1').decode()
location_escaped = normalize_url(location)
if location != location_escaped:
@@ -230,13 +186,15 @@ def make_socks_conn_class(base_class, socks_proxy):
proxy_args = make_socks_proxy_opts(socks_proxy)
class SocksConnection(base_class):
def connect(self):
self.sock = sockssocket()
self.sock.setproxy(**proxy_args)
if type(self.timeout) in (int, float): # noqa: E721
self.sock.settimeout(self.timeout)
self.sock.connect((self.host, self.port))
_create_connection = create_connection
def connect(self):
self.sock = create_connection(
(proxy_args['addr'], proxy_args['port']),
timeout=self.timeout,
source_address=self.source_address,
_create_socket_func=functools.partial(
create_socks_proxy_socket, (self.host, self.port), proxy_args))
if isinstance(self, http.client.HTTPSConnection):
self.sock = self._context.wrap_socket(self.sock, server_hostname=self.host)
@@ -288,8 +246,8 @@ class ProxyHandler(urllib.request.BaseHandler):
def __init__(self, proxies=None):
self.proxies = proxies
# Set default handlers
for type in ('http', 'https', 'ftp'):
setattr(self, '%s_open' % type, lambda r, meth=self.proxy_open: meth(r))
for scheme in ('http', 'https', 'ftp'):
setattr(self, f'{scheme}_open', lambda r, meth=self.proxy_open: meth(r))
def proxy_open(self, req):
proxy = select_proxy(req.get_full_url(), self.proxies)
@@ -365,7 +323,7 @@ def handle_sslerror(e: ssl.SSLError):
def handle_response_read_exceptions(e):
if isinstance(e, http.client.IncompleteRead):
raise IncompleteRead(partial=e.partial, cause=e, expected=e.expected) from e
raise IncompleteRead(partial=len(e.partial), cause=e, expected=e.expected) from e
elif isinstance(e, ssl.SSLError):
handle_sslerror(e)
elif isinstance(e, (OSError, EOFError, http.client.HTTPException, *CONTENT_DECODE_ERRORS)):
@@ -390,14 +348,15 @@ class UrllibRH(RequestHandler, InstanceStoreMixin):
super()._check_extensions(extensions)
extensions.pop('cookiejar', None)
extensions.pop('timeout', None)
extensions.pop('legacy_ssl', None)
def _create_instance(self, proxies, cookiejar):
def _create_instance(self, proxies, cookiejar, legacy_ssl_support=None):
opener = urllib.request.OpenerDirector()
handlers = [
ProxyHandler(proxies),
HTTPHandler(
debuglevel=int(bool(self.verbose)),
context=self._make_sslcontext(),
context=self._make_sslcontext(legacy_ssl_support=legacy_ssl_support),
source_address=self.source_address),
HTTPCookieProcessor(cookiejar),
DataHandler(),
@@ -420,26 +379,29 @@ class UrllibRH(RequestHandler, InstanceStoreMixin):
opener.addheaders = []
return opener
def _send(self, request):
headers = self._merge_headers(request.headers)
def _prepare_headers(self, _, headers):
add_accept_encoding_header(headers, SUPPORTED_ENCODINGS)
def _send(self, request):
headers = self._get_headers(request)
urllib_req = urllib.request.Request(
url=request.url,
data=request.data,
headers=dict(headers),
method=request.method
headers=headers,
method=request.method,
)
opener = self._get_instance(
proxies=request.proxies or self.proxies,
cookiejar=request.extensions.get('cookiejar') or self.cookiejar
proxies=self._get_proxies(request),
cookiejar=self._get_cookiejar(request),
legacy_ssl_support=request.extensions.get('legacy_ssl'),
)
try:
res = opener.open(urllib_req, timeout=float(request.extensions.get('timeout') or self.timeout))
res = opener.open(urllib_req, timeout=self._calculate_timeout(request))
except urllib.error.HTTPError as e:
if isinstance(e.fp, (http.client.HTTPResponse, urllib.response.addinfourl)):
# Prevent file object from being closed when urllib.error.HTTPError is destroyed.
e._closer.file = None
e._closer.close_called = True
raise HTTPError(UrllibResponseAdapter(e.fp), redirect_loop='redirect error' in str(e)) from e
raise # unexpected
except urllib.error.URLError as e:

View File

@@ -0,0 +1,189 @@
from __future__ import annotations
import contextlib
import functools
import io
import logging
import ssl
import sys
from ._helper import (
create_connection,
create_socks_proxy_socket,
make_socks_proxy_opts,
select_proxy,
)
from .common import Features, Response, register_rh
from .exceptions import (
CertificateVerifyError,
HTTPError,
ProxyError,
RequestError,
SSLError,
TransportError,
)
from .websocket import WebSocketRequestHandler, WebSocketResponse
from ..dependencies import websockets
from ..socks import ProxyError as SocksProxyError
from ..utils import int_or_none
if not websockets:
raise ImportError('websockets is not installed')
import websockets.version
websockets_version = tuple(map(int_or_none, websockets.version.version.split('.')))
if websockets_version < (13, 0):
websockets._yt_dlp__version = f'{websockets.version.version} (unsupported)'
raise ImportError('Only websockets>=13.0 is supported')
import websockets.sync.client
from websockets.uri import parse_uri
# In websockets Connection, recv_exc and recv_events_exc are defined
# after the recv events handler thread is started [1].
# On our CI using PyPy, in some cases a race condition may occur
# where the recv events handler thread tries to use these attributes before they are defined [2].
# 1: https://github.com/python-websockets/websockets/blame/de768cf65e7e2b1a3b67854fb9e08816a5ff7050/src/websockets/sync/connection.py#L93
# 2: "AttributeError: 'ClientConnection' object has no attribute 'recv_events_exc'. Did you mean: 'recv_events'?"
import websockets.sync.connection # isort: split
with contextlib.suppress(Exception):
websockets.sync.connection.Connection.recv_exc = None
class WebsocketsResponseAdapter(WebSocketResponse):
def __init__(self, ws: websockets.sync.client.ClientConnection, url):
super().__init__(
fp=io.BytesIO(ws.response.body or b''),
url=url,
headers=ws.response.headers,
status=ws.response.status_code,
reason=ws.response.reason_phrase,
)
self._ws = ws
def close(self):
self._ws.close()
super().close()
def send(self, message):
# https://websockets.readthedocs.io/en/stable/reference/sync/client.html#websockets.sync.client.ClientConnection.send
try:
return self._ws.send(message)
except (websockets.exceptions.WebSocketException, RuntimeError, TimeoutError) as e:
raise TransportError(cause=e) from e
except SocksProxyError as e:
raise ProxyError(cause=e) from e
except TypeError as e:
raise RequestError(cause=e) from e
def recv(self):
# https://websockets.readthedocs.io/en/stable/reference/sync/client.html#websockets.sync.client.ClientConnection.recv
try:
return self._ws.recv()
except SocksProxyError as e:
raise ProxyError(cause=e) from e
except (websockets.exceptions.WebSocketException, RuntimeError, TimeoutError) as e:
raise TransportError(cause=e) from e
@register_rh
class WebsocketsRH(WebSocketRequestHandler):
"""
Websockets request handler
https://websockets.readthedocs.io
https://github.com/python-websockets/websockets
"""
_SUPPORTED_URL_SCHEMES = ('wss', 'ws')
_SUPPORTED_PROXY_SCHEMES = ('socks4', 'socks4a', 'socks5', 'socks5h')
_SUPPORTED_FEATURES = (Features.ALL_PROXY, Features.NO_PROXY)
RH_NAME = 'websockets'
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.__logging_handlers = {}
for name in ('websockets.client', 'websockets.server'):
logger = logging.getLogger(name)
handler = logging.StreamHandler(stream=sys.stdout)
handler.setFormatter(logging.Formatter(f'{self.RH_NAME}: %(message)s'))
self.__logging_handlers[name] = handler
logger.addHandler(handler)
if self.verbose:
logger.setLevel(logging.DEBUG)
def _check_extensions(self, extensions):
super()._check_extensions(extensions)
extensions.pop('timeout', None)
extensions.pop('cookiejar', None)
extensions.pop('legacy_ssl', None)
extensions.pop('keep_header_casing', None)
def close(self):
# Remove the logging handler that contains a reference to our logger
# See: https://github.com/yt-dlp/yt-dlp/issues/8922
for name, handler in self.__logging_handlers.items():
logging.getLogger(name).removeHandler(handler)
def _prepare_headers(self, request, headers):
if 'cookie' not in headers:
cookiejar = self._get_cookiejar(request)
cookie_header = cookiejar.get_cookie_header(request.url)
if cookie_header:
headers['cookie'] = cookie_header
def _send(self, request):
timeout = self._calculate_timeout(request)
headers = self._get_headers(request)
wsuri = parse_uri(request.url)
create_conn_kwargs = {
'source_address': (self.source_address, 0) if self.source_address else None,
'timeout': timeout,
}
proxy = select_proxy(request.url, self._get_proxies(request))
try:
if proxy:
socks_proxy_options = make_socks_proxy_opts(proxy)
sock = create_connection(
address=(socks_proxy_options['addr'], socks_proxy_options['port']),
_create_socket_func=functools.partial(
create_socks_proxy_socket, (wsuri.host, wsuri.port), socks_proxy_options),
**create_conn_kwargs,
)
else:
sock = create_connection(
address=(wsuri.host, wsuri.port),
**create_conn_kwargs,
)
ssl_ctx = self._make_sslcontext(legacy_ssl_support=request.extensions.get('legacy_ssl'))
conn = websockets.sync.client.connect(
sock=sock,
uri=request.url,
additional_headers=headers,
open_timeout=timeout,
user_agent_header=None,
ssl=ssl_ctx if wsuri.secure else None,
close_timeout=0, # not ideal, but prevents yt-dlp hanging
)
return WebsocketsResponseAdapter(conn, url=request.url)
# Exceptions as per https://websockets.readthedocs.io/en/stable/reference/sync/client.html
except SocksProxyError as e:
raise ProxyError(cause=e) from e
except websockets.exceptions.InvalidURI as e:
raise RequestError(cause=e) from e
except ssl.SSLCertVerificationError as e:
raise CertificateVerifyError(cause=e) from e
except ssl.SSLError as e:
raise SSLError(cause=e) from e
except websockets.exceptions.InvalidStatus as e:
raise HTTPError(
Response(
fp=io.BytesIO(e.response.body),
url=request.url,
headers=e.response.headers,
status=e.response.status_code,
reason=e.response.reason_phrase),
) from e
except (OSError, TimeoutError, websockets.exceptions.WebSocketException) as e:
raise TransportError(cause=e) from e

View File

@@ -31,6 +31,8 @@ from ..utils import (
)
from ..utils.networking import HTTPHeaderDict, normalize_url
DEFAULT_TIMEOUT = 20
def register_preference(*handlers: type[RequestHandler]):
assert all(issubclass(handler, RequestHandler) for handler in handlers)
@@ -68,6 +70,7 @@ class RequestDirector:
def close(self):
for handler in self.handlers.values():
handler.close()
self.handlers.clear()
def add_handler(self, handler: RequestHandler):
"""Add a handler. If a handler of the same RH_KEY exists, it will overwrite it"""
@@ -80,8 +83,8 @@ class RequestDirector:
rh: sum(pref(rh, request) for pref in self.preferences)
for rh in self.handlers.values()
}
self._print_verbose('Handler preferences for this request: %s' % ', '.join(
f'{rh.RH_NAME}={pref}' for rh, pref in preferences.items()))
self._print_verbose('Handler preferences for this request: {}'.format(', '.join(
f'{rh.RH_NAME}={pref}' for rh, pref in preferences.items())))
return sorted(self.handlers.values(), key=preferences.get, reverse=True)
def _print_verbose(self, msg):
@@ -202,6 +205,8 @@ class RequestHandler(abc.ABC):
The following extensions are defined for RequestHandler:
- `cookiejar`: Cookiejar to use for this request.
- `timeout`: socket timeout to use for this request.
- `legacy_ssl`: Enable legacy SSL options for this request. See legacy_ssl_support.
- `keep_header_casing`: Keep the casing of headers when sending the request.
To enable these, add extensions.pop('<extension>', None) to _check_extensions
Apart from the url protocol, proxies dict may contain the following keys:
@@ -221,11 +226,11 @@ class RequestHandler(abc.ABC):
headers: HTTPHeaderDict = None,
cookiejar: YoutubeDLCookieJar = None,
timeout: float | int | None = None,
proxies: dict = None,
source_address: str = None,
proxies: dict | None = None,
source_address: str | None = None,
verbose: bool = False,
prefer_system_certs: bool = False,
client_cert: dict[str, str | None] = None,
client_cert: dict[str, str | None] | None = None,
verify: bool = True,
legacy_ssl_support: bool = False,
**_,
@@ -234,7 +239,7 @@ class RequestHandler(abc.ABC):
self._logger = logger
self.headers = headers or {}
self.cookiejar = cookiejar if cookiejar is not None else YoutubeDLCookieJar()
self.timeout = float(timeout or 20)
self.timeout = float(timeout or DEFAULT_TIMEOUT)
self.proxies = proxies or {}
self.source_address = source_address
self.verbose = verbose
@@ -244,10 +249,10 @@ class RequestHandler(abc.ABC):
self.legacy_ssl_support = legacy_ssl_support
super().__init__()
def _make_sslcontext(self):
def _make_sslcontext(self, legacy_ssl_support=None):
return make_ssl_context(
verify=self.verify,
legacy_support=self.legacy_ssl_support,
legacy_support=legacy_ssl_support if legacy_ssl_support is not None else self.legacy_ssl_support,
use_certifi=not self.prefer_system_certs,
**self._client_cert,
)
@@ -255,6 +260,33 @@ class RequestHandler(abc.ABC):
def _merge_headers(self, request_headers):
return HTTPHeaderDict(self.headers, request_headers)
def _prepare_headers(self, request: Request, headers: HTTPHeaderDict) -> None: # noqa: B027
"""Additional operations to prepare headers before building. To be extended by subclasses.
@param request: Request object
@param headers: Merged headers to prepare
"""
def _get_headers(self, request: Request) -> dict[str, str]:
"""
Get headers for external use.
Subclasses may define a _prepare_headers method to modify headers after merge but before building.
"""
headers = self._merge_headers(request.headers)
self._prepare_headers(request, headers)
if request.extensions.get('keep_header_casing'):
return headers.sensitive()
return dict(headers)
def _calculate_timeout(self, request):
return float(request.extensions.get('timeout') or self.timeout)
def _get_cookiejar(self, request):
cookiejar = request.extensions.get('cookiejar')
return self.cookiejar if cookiejar is None else cookiejar
def _get_proxies(self, request):
return (request.proxies or self.proxies).copy()
def _check_url_scheme(self, request: Request):
scheme = urllib.parse.urlparse(request.url).scheme.lower()
if self._SUPPORTED_URL_SCHEMES is not None and scheme not in self._SUPPORTED_URL_SCHEMES:
@@ -302,6 +334,8 @@ class RequestHandler(abc.ABC):
"""Check extensions for unsupported extensions. Subclasses should extend this."""
assert isinstance(extensions.get('cookiejar'), (YoutubeDLCookieJar, NoneType))
assert isinstance(extensions.get('timeout'), (float, int, NoneType))
assert isinstance(extensions.get('legacy_ssl'), (bool, NoneType))
assert isinstance(extensions.get('keep_header_casing'), (bool, NoneType))
def _validate(self, request):
self._check_url_scheme(request)
@@ -329,7 +363,7 @@ class RequestHandler(abc.ABC):
"""Handle a request from start to finish. Redefine in subclasses."""
pass
def close(self):
def close(self): # noqa: B027
pass
@classproperty
@@ -366,11 +400,11 @@ class Request:
self,
url: str,
data: RequestData = None,
headers: typing.Mapping = None,
proxies: dict = None,
query: dict = None,
method: str = None,
extensions: dict = None
headers: typing.Mapping | None = None,
proxies: dict | None = None,
query: dict | None = None,
method: str | None = None,
extensions: dict | None = None,
):
self._headers = HTTPHeaderDict()
@@ -445,7 +479,7 @@ class Request:
@headers.setter
def headers(self, new_headers: Mapping):
"""Replaces headers of the request. If not a CaseInsensitiveDict, it will be converted to one."""
"""Replaces headers of the request. If not a HTTPHeaderDict, it will be converted to one."""
if isinstance(new_headers, HTTPHeaderDict):
self._headers = new_headers
elif isinstance(new_headers, Mapping):
@@ -453,9 +487,10 @@ class Request:
else:
raise TypeError('headers must be a mapping')
def update(self, url=None, data=None, headers=None, query=None):
def update(self, url=None, data=None, headers=None, query=None, extensions=None):
self.data = data if data is not None else self.data
self.headers.update(headers or {})
self.extensions.update(extensions or {})
self.url = update_url_query(url or self.url, query or {})
def copy(self):
@@ -470,6 +505,7 @@ class Request:
HEADRequest = functools.partial(Request, method='HEAD')
PATCHRequest = functools.partial(Request, method='PATCH')
PUTRequest = functools.partial(Request, method='PUT')
@@ -486,15 +522,18 @@ class Response(io.IOBase):
@param headers: response headers.
@param status: Response HTTP status code. Default is 200 OK.
@param reason: HTTP status reason. Will use built-in reasons based on status code if not provided.
@param extensions: Dictionary of handler-specific response extensions.
"""
def __init__(
self,
fp: typing.IO,
fp: io.IOBase,
url: str,
headers: Mapping[str, str],
status: int = 200,
reason: str = None):
reason: str | None = None,
extensions: dict | None = None,
):
self.fp = fp
self.headers = Message()
@@ -506,11 +545,12 @@ class Response(io.IOBase):
self.reason = reason or HTTPStatus(status).phrase
except ValueError:
self.reason = None
self.extensions = extensions or {}
def readable(self):
return self.fp.readable()
def read(self, amt: int = None) -> bytes:
def read(self, amt: int | None = None) -> bytes:
# Expected errors raised here should be of type RequestError or subclasses.
# Subclasses should redefine this method with more precise error handling.
try:

View File

@@ -1,9 +1,8 @@
from __future__ import annotations
import typing
import urllib.error
from ..utils import YoutubeDLError, deprecation_warning
from ..utils import YoutubeDLError
if typing.TYPE_CHECKING:
from .common import RequestHandler, Response
@@ -14,7 +13,7 @@ class RequestError(YoutubeDLError):
self,
msg: str | None = None,
cause: Exception | str | None = None,
handler: RequestHandler = None
handler: RequestHandler = None,
):
self.handler = handler
self.cause = cause
@@ -75,10 +74,10 @@ class HTTPError(RequestError):
class IncompleteRead(TransportError):
def __init__(self, partial, expected=None, **kwargs):
def __init__(self, partial: int, expected: int | None = None, **kwargs):
self.partial = partial
self.expected = expected
msg = f'{len(partial)} bytes read'
msg = f'{partial} bytes read'
if expected is not None:
msg += f', {expected} more expected'
@@ -101,117 +100,4 @@ class ProxyError(TransportError):
pass
class _CompatHTTPError(urllib.error.HTTPError, HTTPError):
"""
Provides backwards compatibility with urllib.error.HTTPError.
Do not use this class directly, use HTTPError instead.
"""
def __init__(self, http_error: HTTPError):
super().__init__(
url=http_error.response.url,
code=http_error.status,
msg=http_error.msg,
hdrs=http_error.response.headers,
fp=http_error.response
)
self._closer.file = None # Disable auto close
self._http_error = http_error
HTTPError.__init__(self, http_error.response, redirect_loop=http_error.redirect_loop)
@property
def status(self):
return self._http_error.status
@status.setter
def status(self, value):
return
@property
def reason(self):
return self._http_error.reason
@reason.setter
def reason(self, value):
return
@property
def headers(self):
deprecation_warning('HTTPError.headers is deprecated, use HTTPError.response.headers instead')
return self._http_error.response.headers
@headers.setter
def headers(self, value):
return
def info(self):
deprecation_warning('HTTPError.info() is deprecated, use HTTPError.response.headers instead')
return self.response.headers
def getcode(self):
deprecation_warning('HTTPError.getcode is deprecated, use HTTPError.status instead')
return self.status
def geturl(self):
deprecation_warning('HTTPError.geturl is deprecated, use HTTPError.response.url instead')
return self.response.url
@property
def code(self):
deprecation_warning('HTTPError.code is deprecated, use HTTPError.status instead')
return self.status
@code.setter
def code(self, value):
return
@property
def url(self):
deprecation_warning('HTTPError.url is deprecated, use HTTPError.response.url instead')
return self.response.url
@url.setter
def url(self, value):
return
@property
def hdrs(self):
deprecation_warning('HTTPError.hdrs is deprecated, use HTTPError.response.headers instead')
return self.response.headers
@hdrs.setter
def hdrs(self, value):
return
@property
def filename(self):
deprecation_warning('HTTPError.filename is deprecated, use HTTPError.response.url instead')
return self.response.url
@filename.setter
def filename(self, value):
return
def __getattr__(self, name):
# File operations are passed through the response.
# Warn for some commonly used ones
passthrough_warnings = {
'read': 'response.read()',
# technically possibly due to passthrough, but we should discourage this
'get_header': 'response.get_header()',
'readable': 'response.readable()',
'closed': 'response.closed',
'tell': 'response.tell()',
}
if name in passthrough_warnings:
deprecation_warning(f'HTTPError.{name} is deprecated, use HTTPError.{passthrough_warnings[name]} instead')
return super().__getattr__(name)
def __str__(self):
return str(self._http_error)
def __repr__(self):
return repr(self._http_error)
network_exceptions = (HTTPError, TransportError)

View File

@@ -0,0 +1,155 @@
from __future__ import annotations
import re
from abc import ABC
from dataclasses import dataclass
from typing import Any
from .common import RequestHandler, register_preference, Request
from .exceptions import UnsupportedRequest
from ..compat.types import NoneType
from ..utils import classproperty, join_nonempty
from ..utils.networking import std_headers, HTTPHeaderDict
@dataclass(order=True, frozen=True)
class ImpersonateTarget:
"""
A target for browser impersonation.
Parameters:
@param client: the client to impersonate
@param version: the client version to impersonate
@param os: the client OS to impersonate
@param os_version: the client OS version to impersonate
Note: None is used to indicate to match any.
"""
client: str | None = None
version: str | None = None
os: str | None = None
os_version: str | None = None
def __post_init__(self):
if self.version and not self.client:
raise ValueError('client is required if version is set')
if self.os_version and not self.os:
raise ValueError('os is required if os_version is set')
def __contains__(self, target: ImpersonateTarget):
if not isinstance(target, ImpersonateTarget):
return False
return (
(self.client is None or target.client is None or self.client == target.client)
and (self.version is None or target.version is None or self.version == target.version)
and (self.os is None or target.os is None or self.os == target.os)
and (self.os_version is None or target.os_version is None or self.os_version == target.os_version)
)
def __str__(self):
return f'{join_nonempty(self.client, self.version)}:{join_nonempty(self.os, self.os_version)}'.rstrip(':')
@classmethod
def from_str(cls, target: str):
mobj = re.fullmatch(r'(?:(?P<client>[^:-]+)(?:-(?P<version>[^:-]+))?)?(?::(?:(?P<os>[^:-]+)(?:-(?P<os_version>[^:-]+))?)?)?', target)
if not mobj:
raise ValueError(f'Invalid impersonate target "{target}"')
return cls(**mobj.groupdict())
class ImpersonateRequestHandler(RequestHandler, ABC):
"""
Base class for request handlers that support browser impersonation.
This provides a method for checking the validity of the impersonate extension,
which can be used in _check_extensions.
Impersonate targets consist of a client, version, os and os_ver.
See the ImpersonateTarget class for more details.
The following may be defined:
- `_SUPPORTED_IMPERSONATE_TARGET_MAP`: a dict mapping supported targets to custom object.
Any Request with an impersonate target not in this list will raise an UnsupportedRequest.
Set to None to disable this check.
Note: Entries are in order of preference
Parameters:
@param impersonate: the default impersonate target to use for requests.
Set to None to disable impersonation.
"""
_SUPPORTED_IMPERSONATE_TARGET_MAP: dict[ImpersonateTarget, Any] = {}
def __init__(self, *, impersonate: ImpersonateTarget = None, **kwargs):
super().__init__(**kwargs)
self.impersonate = impersonate
def _check_impersonate_target(self, target: ImpersonateTarget):
assert isinstance(target, (ImpersonateTarget, NoneType))
if target is None or not self.supported_targets:
return
if not self.is_supported_target(target):
raise UnsupportedRequest(f'Unsupported impersonate target: {target}')
def _check_extensions(self, extensions):
super()._check_extensions(extensions)
if 'impersonate' in extensions:
self._check_impersonate_target(extensions.get('impersonate'))
def _validate(self, request):
super()._validate(request)
self._check_impersonate_target(self.impersonate)
def _resolve_target(self, target: ImpersonateTarget | None):
"""Resolve a target to a supported target."""
if target is None:
return
for supported_target in self.supported_targets:
if target in supported_target:
if self.verbose:
self._logger.stdout(
f'{self.RH_NAME}: resolved impersonate target {target} to {supported_target}')
return supported_target
@classproperty
def supported_targets(cls) -> tuple[ImpersonateTarget, ...]:
return tuple(cls._SUPPORTED_IMPERSONATE_TARGET_MAP.keys())
def is_supported_target(self, target: ImpersonateTarget):
assert isinstance(target, ImpersonateTarget)
return self._resolve_target(target) is not None
def _get_request_target(self, request):
"""Get the requested target for the request"""
return self._resolve_target(request.extensions.get('impersonate') or self.impersonate)
def _prepare_impersonate_headers(self, request: Request, headers: HTTPHeaderDict) -> None: # noqa: B027
"""Additional operations to prepare headers before building. To be extended by subclasses.
@param request: Request object
@param headers: Merged headers to prepare
"""
def _get_impersonate_headers(self, request: Request) -> dict[str, str]:
"""
Get headers for external impersonation use.
Subclasses may define a _prepare_impersonate_headers method to modify headers after merge but before building.
"""
headers = self._merge_headers(request.headers)
if self._get_request_target(request) is not None:
# remove all headers present in std_headers
# TODO: change this to not depend on std_headers
for k, v in std_headers.items():
if headers.get(k) == v:
headers.pop(k)
self._prepare_impersonate_headers(request, headers)
if request.extensions.get('keep_header_casing'):
return headers.sensitive()
return dict(headers)
@register_preference(ImpersonateRequestHandler)
def impersonate_preference(rh, request):
if request.extensions.get('impersonate') or rh.impersonate:
return 1000
return 0

View File

@@ -0,0 +1,23 @@
from __future__ import annotations
import abc
from .common import RequestHandler, Response
class WebSocketResponse(Response):
def send(self, message: bytes | str):
"""
Send a message to the server.
@param message: The message to send. A string (str) is sent as a text frame, bytes is sent as a binary frame.
"""
raise NotImplementedError
def recv(self):
raise NotImplementedError
class WebSocketRequestHandler(RequestHandler, abc.ABC):
pass