Clean up codebase and improve file loading

- Moved plugins to proper sub groups (autopairs, code_minimap, colorize, commentzar, info_bar, markdown_preview, prettify_json, search_replace, tabs_bar, telescope, toggle_source_view, lsp_client)
- Add filter_out_loaded_files to prevent opening already-loaded files
- Add INDEPENDENT source view state
- Fix cursor scroll position on buffer switch
- Fix signal blocking during file load
- Fix word boundary in completion provider
- Refactor code events into single events module
This commit is contained in:
2026-03-08 00:51:28 -06:00
parent a52d5243ab
commit 99dc917de3
229 changed files with 8809 additions and 756 deletions

View File

@@ -0,0 +1,3 @@
"""
Pligin Module
"""

View File

@@ -0,0 +1,3 @@
"""
Pligin Package
"""

View File

@@ -0,0 +1,50 @@
# Python imports
# Lib imports
import gi
gi.require_version('GtkSource', '4')
from gi.repository.GtkSource import Map
from gi.repository import Pango
# Application imports
class CodeMiniMap(Map):
def __init__(self):
super(CodeMiniMap, self).__init__()
self._setup_styling()
self._setup_signals()
self._subscribe_to_events()
self._load_widgets()
self.show()
def _setup_styling(self):
ctx = self.get_style_context()
ctx.add_class("mini-view")
self.set_hexpand(False)
self._set_font_desc()
def _setup_signals(self):
...
def _subscribe_to_events(self):
event_system.subscribe(f"set-mini-view", self.set_smini_view)
def _load_widgets(self):
...
def _set_font_desc(self):
default_font = 'Monospace 1'
desc = Pango.FontDescription(default_font)
desc.set_size(Pango.SCALE) # Set size to 1pt
desc.set_family('BuilderBlocks,' + desc.get_family())
self.set_property('font-desc', desc)
def set_smini_view(self, source_view):
self.set_view(source_view)

View File

@@ -0,0 +1,7 @@
{
"name": "Code MiniMap",
"author": "ITDominator",
"version": "0.0.1",
"support": "",
"requests": {}
}

View File

@@ -0,0 +1,32 @@
# Python imports
# Lib imports
# Application imports
from libs.event_factory import Event_Factory, Code_Event_Types
from plugins.plugin_types import PluginCode
from .code_minimap import CodeMiniMap
code_minimap = CodeMiniMap()
class Plugin(PluginCode):
def __init__(self):
super(Plugin, self).__init__()
def _controller_message(self, event: Code_Event_Types.CodeEvent):
if isinstance(event, Code_Event_Types.FocusedViewEvent):
code_minimap.set_smini_view(event.view)
def load(self):
editors_container = self.request_ui_element("editors-container")
editors_container.add( code_minimap )
def run(self):
...

View File

@@ -0,0 +1,3 @@
"""
Pligin Module
"""

View File

@@ -0,0 +1,3 @@
"""
Pligin Package
"""

View File

@@ -0,0 +1,96 @@
# Python imports
# Lib imports
import gi
gi.require_version('Gtk', '3.0')
from gi.repository import Gtk
from gi.repository import Pango
from gi.repository import Gio
# Application imports
class InfoBarWidget(Gtk.Box):
""" docstring for InfoBarWidget. """
def __init__(self):
super(InfoBarWidget, self).__init__()
self._setup_styling()
self._setup_signals()
self._subscribe_to_events()
self._load_widgets()
self.show_all()
def _setup_styling(self):
self.set_margin_start(25)
self.set_margin_end(25)
def _setup_signals(self):
...
def _subscribe_to_events(self):
...
def _load_widgets(self):
self.path_label = Gtk.Label(label = "...")
self.line_char_label = Gtk.Label(label = "1:0")
self.encoding_label = Gtk.Label(label = "utf-8")
self.file_type_label = Gtk.Label(label = "buffer")
self.add(self.path_label)
self.add(self.line_char_label)
self.add(self.encoding_label)
self.add(self.file_type_label)
self.path_label.set_hexpand(True)
self.path_label.set_ellipsize(Pango.EllipsizeMode.START)
self.path_label.set_single_line_mode(True)
self.path_label.set_max_width_chars(48)
self.line_char_label.set_hexpand(True)
self.encoding_label.set_hexpand(True)
self.file_type_label.set_hexpand(True)
def _set_info_labels(
self,
path: Gio.File or str = None,
line_char: str = None,
file_type: str = None,
encoding_type: str = None
):
self._set_path_label(path)
self._set_line_char_label(line_char)
self._set_file_type_label(file_type)
self._set_encoding_label(encoding_type)
def _set_path_label(self, gfile: Gio.File or str = "..."):
gfile = "" if not gfile else gfile
if isinstance(gfile, str):
self.path_label.set_text( gfile )
self.path_label.set_tooltip_text( gfile )
else:
self.path_label.set_text( gfile.get_path() )
self.path_label.set_tooltip_text( gfile.get_path() )
def _set_line_char_label(self, line_char = "1:1"):
line_char = "1:1" if not line_char else line_char
self.line_char_label.set_text(line_char)
def _set_file_type_label(self, file_type = "buffer"):
file_type = "buffer" if not file_type else file_type
self.file_type_label.set_text(file_type)
def _set_encoding_label(self, encoding_type = "utf-8"):
encoding_type = "utf-8" if not encoding_type else encoding_type
self.encoding_label.set_text(encoding_type)

View File

@@ -0,0 +1,7 @@
{
"name": "Info Bar",
"author": "ITDominator",
"version": "0.0.1",
"support": "",
"requests": {}
}

View File

@@ -0,0 +1,32 @@
# Python imports
# Lib imports
# Application imports
from libs.event_factory import Event_Factory, Code_Event_Types
from plugins.plugin_types import PluginCode
from .info_bar_widget import InfoBarWidget
info_bar_widget = InfoBarWidget()
class Plugin(PluginCode):
def __init__(self):
super(Plugin, self).__init__()
def _controller_message(self, event: Code_Event_Types.CodeEvent):
if isinstance(event, Code_Event_Types.SetInfoLabelsEvent):
info_bar_widget._set_info_labels(*event.info)
def load(self):
header = self.request_ui_element("header-container")
header.add( info_bar_widget )
def run(self):
...

View File

@@ -0,0 +1,3 @@
"""
Pligin Module
"""

View File

@@ -0,0 +1,3 @@
"""
Pligin Package
"""

View File

@@ -0,0 +1,151 @@
{
"_description": "The parameters sent by the client when initializing the language server with the \"initialize\" request. More details at https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification/#initialize",
"processId": "os.getpid()",
"clientInfo": {
"name": "LSP Manager",
"version": "0.0.1"
},
"locale": "en",
"rootPath": "repository_absolute_path",
"rootUri": "pathlib.Path(repository_absolute_path).as_uri()",
"capabilities": {
"textDocument": {
"completion": {
"dynamicRegistration": true,
"contextSupport": true,
"completionItem": {
"snippetSupport": false,
"commitCharactersSupport": true,
"documentationFormat": [
"markdown",
"plaintext"
],
"deprecatedSupport": true,
"preselectSupport": true,
"tagSupport": {
"valueSet": [
1
]
},
"insertReplaceSupport": false,
"resolveSupport": {
"properties": [
"documentation",
"detail",
"additionalTextEdits"
]
},
"insertTextModeSupport": {
"valueSet": [
1,
2
]
},
"labelDetailsSupport": true
},
"insertTextMode": 2,
"completionItemKind": {
"valueSet": [
1,
2,
3,
4,
5,
6,
7,
8,
9,
10,
11,
12,
13,
14,
15,
16,
17,
18,
19,
20,
21,
22,
23,
24,
25
]
},
"completionList": {
"itemDefaults": [
"commitCharacters",
"editRange",
"insertTextFormat",
"insertTextMode"
]
}
},
"hover": {
"dynamicRegistration": true,
"contentFormat": [
"markdown",
"plaintext"
]
},
"signatureHelp": {
"dynamicRegistration": true,
"signatureInformation": {
"documentationFormat": [
"markdown",
"plaintext"
],
"parameterInformation": {
"labelOffsetSupport": true
},
"activeParameterSupport": true
},
"contextSupport": true
},
"definition": {
"dynamicRegistration": true,
"linkSupport": true
},
"references": {
"dynamicRegistration": true
},
"typeDefinition": {
"dynamicRegistration": true,
"linkSupport": true
},
"implementation": {
"dynamicRegistration": true,
"linkSupport": true
},
"colorProvider": {
"dynamicRegistration": true
},
"declaration": {
"dynamicRegistration": true,
"linkSupport": true
},
"callHierarchy": {
"dynamicRegistration": true
},
"inlayHint": {
"dynamicRegistration": true,
"resolveSupport": {
"properties": [
"tooltip",
"textEdits",
"label.tooltip",
"label.location",
"label.command"
]
}
},
"diagnostic": {
"dynamicRegistration": true,
"relatedDocumentSupport": false
}
}
},
"trace": "verbose",
"workspaceFolders": "[\n {\n \"uri\": pathlib.Path(repository_absolute_path).as_uri(),\n \"name\": os.path.basename(repository_absolute_path),\n }\n ]"
}

View File

@@ -0,0 +1,365 @@
{
"java": {
"info": "https://download.eclipse.org/jdtls/",
"info-init-options": "https://github.com/eclipse-jdtls/eclipse.jdt.ls/wiki/Running-the-JAVA-LS-server-from-the-command-line",
"info-import-build": "https://www.javahotchocolate.com/tutorials/build-path.html",
"info-external-class-paths": "https://github.com/eclipse-jdtls/eclipse.jdt.ls/issues/3291",
"link": "https://download.eclipse.org/jdtls/milestones/?d",
"command": "lsp-ws-proxy --listen 4114 -- jdtls",
"alt-command": "lsp-ws-proxy -- jdtls",
"alt-command2": "java-language-server",
"socket": "ws://127.0.0.1:9999/java",
"socket-two": "ws://127.0.0.1:9999/?name=jdtls",
"alt-socket": "ws://127.0.0.1:9999/?name=java-language-server",
"initialization-options": {
"bundles": [
"intellicode-core.jar"
],
"workspaceFolders": [
"file://{workspace.folder}"
],
"extendedClientCapabilities": {
"classFileContentsSupport": true,
"executeClientCommandSupport": false
},
"settings": {
"java": {
"autobuild": {
"enabled": true
},
"jdt": {
"ls": {
"javac": {
"enabled": true
},
"java": {
"home": "{user.home}/Portable_Apps/sdks/javasdk/jdk-22.0.2"
},
"lombokSupport": {
"enabled": true
},
"protobufSupport":{
"enabled": true
},
"androidSupport": {
"enabled": true
}
}
},
"configuration": {
"updateBuildConfiguration": "automatic",
"maven": {
"userSettings": "{user.home}/.config/jdtls/settings.xml",
"globalSettings": "{user.home}/.config/jdtls/settings.xml"
},
"runtimes": [
{
"name": "JavaSE-17",
"path": "/usr/lib/jvm/java-17-openjdk",
"javadoc": "https://docs.oracle.com/en/java/javase/17/docs/api/",
"default": false
},
{
"name": "JavaSE-22",
"path": "{user.home}/Portable_Apps/sdks/javasdk/jdk-22.0.2",
"javadoc": "https://docs.oracle.com/en/java/javase/22/docs/api/",
"default": true
}
]
},
"classPath": [
"{user.home}/.config/jdtls/m2/repository/**/*-sources.jar",
"lib/**/*-sources.jar"
],
"docPath": [
"{user.home}/.config/jdtls/m2/repository/**/*-javadoc.jar",
"lib/**/*-javadoc.jar"
],
"project": {
"encoding": "ignore",
"outputPath": "bin",
"referencedLibraries": [
"{user.home}/.config/jdtls/m2/repository/**/*.jar",
"lib/**/*.jar"
],
"importOnFirstTimeStartup": "automatic",
"importHint": true,
"resourceFilters": [
"node_modules",
"\\.git"
],
"sourcePaths": [
"src",
"{user.home}/.config/jdtls/m2/repository/**/*.jar"
]
},
"sources": {
"organizeImports": {
"starThreshold": 99,
"staticStarThreshold": 99
}
},
"imports": {
"gradle": {
"wrapper": {
"checksums": []
}
}
},
"import": {
"maven": {
"enabled": true,
"offline": {
"enabled": false
},
"disableTestClasspathFlag": false
},
"gradle": {
"enabled": false,
"wrapper": {
"enabled": true
},
"version": "",
"home": "abs(static/gradle-7.3.3)",
"java": {
"home": "abs(static/launch_jres/17.0.6-linux-x86_64)"
},
"offline": {
"enabled": false
},
"arguments": [],
"jvmArguments": [],
"user": {
"home": ""
},
"annotationProcessing": {
"enabled": true
}
},
"exclusions": [
"**/node_modules/**",
"**/.metadata/**",
"**/archetype-resources/**",
"**/META-INF/maven/**"
],
"generatesMetadataFilesAtProjectRoot": false
},
"maven": {
"downloadSources": true,
"updateSnapshots": true
},
"silentNotification": true,
"contentProvider": {
"preferred": "fernflower"
},
"signatureHelp": {
"enabled": true,
"description": {
"enabled": true
}
},
"completion": {
"enabled": true,
"engine": "ecj",
"matchCase": "firstletter",
"maxResults": 25,
"guessMethodArguments": true,
"lazyResolveTextEdit": {
"enabled": true
},
"postfix": {
"enabled": true
},
"favoriteStaticMembers": [
"org.junit.Assert.*",
"org.junit.Assume.*",
"org.junit.jupiter.api.Assertions.*",
"org.junit.jupiter.api.Assumptions.*",
"org.junit.jupiter.api.DynamicContainer.*",
"org.junit.jupiter.api.DynamicTest.*"
],
"importOrder": [
"#",
"java",
"javax",
"org",
"com"
]
},
"references": {
"includeAccessors": true,
"includeDecompiledSources": true
},
"codeGeneration": {
"toString": {
"template": "${object.className}{${member.name()}=${member.value}, ${otherMembers}}"
},
"insertionLocation": "afterCursor",
"useBlocks": true
},
"implementationsCodeLens": {
"enabled": true
},
"referencesCodeLens": {
"enabled": true
},
"progressReports": {
"enabled": false
},
"saveActions": {
"organizeImports": true
}
}
}
}
},
"python": {
"info": "https://github.com/python-lsp/python-lsp-server",
"command": "lsp-ws-proxy -- pylsp",
"alt-command": "pylsp",
"alt-command2": "lsp-ws-proxy --listen 4114 -- pylsp",
"alt-command3": "pylsp --ws --port 4114",
"socket": "ws://127.0.0.1:9999/python",
"socket-two": "ws://127.0.0.1:9999/?name=pylsp",
"initialization-options": {
"pylsp": {
"rope": {
"ropeFolder": "{user.home}/.config/newton/lsps/ropeproject"
},
"plugins": {
"ruff": {
"enabled": true,
"extendSelect": ["I"],
"lineLength": 80
},
"pycodestyle": {
"enabled": false
},
"pyflakes": {
"enabled": false
},
"pylint": {
"enabled": true
},
"mccabe": {
"enabled": false
},
"pylsp_rope": {
"rename": false
},
"rope_rename": {
"enabled": false
},
"rope_autoimport": {
"enabled": true
},
"rope_completion": {
"enabled": false,
"eager": false
},
"jedi_rename": {
"enabled": true
},
"jedi_completion": {
"enabled": true,
"include_class_objects": true,
"include_function_objects": true,
"fuzzy": false
},
"jedi": {
"root_dir": "file://{workspace.folder}",
"extra_paths": [
"{user.home}/Portable_Apps/py-venvs/pylsp-venv/venv/lib/python3.10/site-packages"
]
}
}
}
}
},
"python - jedi-language-server": {
"hidden": true,
"info": "https://pypi.org/project/jedi-language-server/",
"command": "jedi-language-server",
"alt-command": "lsp-ws-proxy --listen 3030 -- jedi-language-server",
"socket": "ws://127.0.0.1:9999/python",
"socket-two": "ws://127.0.0.1:9999/?name=jedi-language-server",
"initialization-options": {
"jediSettings": {
"autoImportModules": [],
"caseInsensitiveCompletion": true,
"debug": false
},
"completion": {
"disableSnippets": false,
"resolveEagerly": false,
"ignorePatterns": []
},
"markupKindPreferred": "markdown",
"workspace": {
"extraPaths": [
"{user.home}/Portable_Apps/py-venvs/pylsp-venv/venv/lib/python3.10/site-packages"
],
"environmentPath": "{user.home}/Portable_Apps/py-venvs/gtk-apps-venv/venv/bin/python",
"symbols": {
"ignoreFolders": [
".nox",
".tox",
".venv",
"__pycache__",
"venv"
],
"maxSymbols": 20
}
}
}
},
"cpp": {
"info": "https://clangd.llvm.org/",
"command": "lsp-ws-proxy -- clangd",
"alt-command": "clangd",
"socket": "ws://127.0.0.1:9999/cpp",
"socket-two": "ws://127.0.0.1:9999/?name=clangd",
"initialization-options": {}
},
"c": {
"hidden": true,
"info": "https://clangd.llvm.org/",
"command": "lsp-ws-proxy -- clangd",
"alt-command": "clangd",
"socket": "ws://127.0.0.1:9999/c",
"socket-two": "ws://127.0.0.1:9999/?name=clangd",
"initialization-options": {}
},
"go": {
"info": "https://pkg.go.dev/golang.org/x/tools/gopls#section-readme",
"command": "lsp-ws-proxy -- gopls",
"alt-command": "gopls",
"socket": "ws://127.0.0.1:9999/go",
"socket-two": "ws://127.0.0.1:9999/?name=gopls",
"initialization-options": {}
},
"typescript": {
"info": "https://github.com/typescript-language-server/typescript-language-server",
"command": "lsp-ws-proxy -- typescript-language-server",
"alt-command": "typescript-language-server --stdio",
"socket": "ws://127.0.0.1:9999/typescript",
"socket-two": "ws://127.0.0.1:9999/?name=ts",
"initialization-options": {}
},
"sh": {
"info": "",
"command": "",
"alt-command": "",
"socket": "ws://127.0.0.1:9999/bash",
"socket-two": "ws://127.0.0.1:9999/?name=shell",
"initialization-options": {}
},
"lua": {
"info": "https://github.com/LuaLS/lua-language-server",
"command": "lsp-ws-proxy -- lua-language-server",
"alt-command": "lua-language-server",
"socket": "ws://127.0.0.1:9999/lua",
"socket-two": "ws://127.0.0.1:9999/?name=lua",
"initialization-options": {}
}
}

View File

@@ -0,0 +1,3 @@
"""
Plugin Controller Module
"""

View File

@@ -0,0 +1,67 @@
# Python imports
import threading
# Lib imports
import gi
from gi.repository import GLib
# Application imports
from libs.dto.code.lsp.lsp_messages import get_message_str
from libs.dto.code.lsp.lsp_message_structs import LSPResponseTypes, ClientRequest, ClientNotification
from .lsp_controller_websocket import LSPControllerWebsocket
class LSPController(LSPControllerWebsocket):
def __init__(self):
super(LSPController, self).__init__()
# https://github.com/microsoft/multilspy/tree/main/src/multilspy/language_servers
# initialize-params-slim.json was created off of jedi_language_server one
# self._init_params = settings_manager.get_lsp_init_data()
self._language: str = ""
self._init_params: dict = {}
self._event_history: dict[str] = {}
try:
from os import path
import json
_USER_HOME = path.expanduser('~')
_SCRIPT_PTH = path.dirname( path.realpath(__file__) )
_LSP_INIT_CONFIG = f"{_SCRIPT_PTH}/../configs/initialize-params-slim.json"
with open(_LSP_INIT_CONFIG) as file:
data = file.read().replace("{user.home}", _USER_HOME)
self._init_params = json.loads(data)
except Exception as e:
logger.error( f"LSP Controller: {_LSP_INIT_CONFIG}\n\t\t{repr(e)}" )
self._message_id: int = -1
self._socket = None
self.read_lock = threading.Lock()
self.write_lock = threading.Lock()
def set_language(self, language):
self._language = language
def set_socket(self, socket: str):
self._socket = socket
def unset_socket(self):
self._socket = None
def send_notification(self, method: str, params: {} = {}):
self._send_message( ClientNotification(method, params) )
def send_request(self, method: str, params: {} = {}):
self._message_id += 1
self._event_history[self._message_id] = method
self._send_message( ClientRequest(self._message_id, method, params) )
def get_event_by_id(self, message_id: int):
if not message_id in self._event_history: return
return self._event_history[message_id]
def handle_lsp_response(self, lsp_response: LSPResponseTypes):
raise NotImplementedError

View File

@@ -0,0 +1,19 @@
# Python imports
# Lib imports
# Application imports
from .lsp_controller_events import LSPControllerEvents
from libs.dto.code.lsp.lsp_message_structs import ClientRequest, ClientNotification
class LSPControllerBase(LSPControllerEvents):
def _send_message(self, data: ClientRequest or ClientNotification):
raise NotImplementedError
def start_client(self):
raise NotImplementedError
def stop_client(self):
raise NotImplementedError

View File

@@ -0,0 +1,121 @@
# Python imports
import os
# Lib imports
from gi.repository import GLib
# Application imports
from libs.dto.code.lsp.lsp_messages import get_message_obj
from libs.dto.code.lsp.lsp_messages import didopen_notification
from libs.dto.code.lsp.lsp_messages import didsave_notification
from libs.dto.code.lsp.lsp_messages import didclose_notification
from libs.dto.code.lsp.lsp_messages import didchange_notification
from libs.dto.code.lsp.lsp_messages import completion_request
from libs.dto.code.lsp.lsp_messages import definition_request
from libs.dto.code.lsp.lsp_messages import references_request
from libs.dto.code.lsp.lsp_messages import symbols_request
class LSPControllerEvents:
def send_initialize_message(self, init_ops: dict, workspace_file: str, workspace_uri: str):
folder_name = os.path.basename(workspace_file)
self._init_params["processId"] = None
self._init_params["rootPath"] = workspace_file
self._init_params["rootUri"] = workspace_uri
self._init_params["workspaceFolders"] = [
{
"name": folder_name,
"uri": workspace_uri
}
]
self._init_params["initializationOptions"] = init_ops
self.send_request("initialize", self._init_params)
def send_initialized_message(self):
self.send_notification("initialized")
def _lsp_did_open(self, data: dict):
method = "textDocument/didOpen"
params = didopen_notification["params"]
params["textDocument"]["uri"] = data["uri"]
params["textDocument"]["languageId"] = data["language_id"]
params["textDocument"]["text"] = data["text"]
GLib.idle_add( self.send_notification, method, params )
def _lsp_did_save(self, data: dict):
method = "textDocument/didSave"
params = didsave_notification["params"]
params["textDocument"]["uri"] = data["uri"]
params["text"] = data["text"]
GLib.idle_add( self.send_notification, method, params )
def _lsp_did_close(self, data: dict):
method = "textDocument/didClose"
params = didclose_notification["params"]
params["textDocument"]["uri"] = data["uri"]
GLib.idle_add( self.send_notification, method, params )
def _lsp_did_change(self, data: dict):
method = "textDocument/didChange"
params = didchange_notification["params"]
params["textDocument"]["uri"] = data["uri"]
params["textDocument"]["languageId"] = data["language_id"]
params["textDocument"]["version"] = data["version"]
contentChanges = params["contentChanges"][0]
contentChanges["text"] = data["text"]
GLib.idle_add( self.send_notification, method, params )
# def _lsp_did_change(self, data: dict):
# method = "textDocument/didChange"
# params = didchange_notification_range["params"]
# params["textDocument"]["uri"] = data["uri"]
# params["textDocument"]["languageId"] = data["language_id"]
# params["textDocument"]["version"] = data["version"]
# contentChanges = params["contentChanges"][0]
# start = contentChanges["range"]["start"]
# end = contentChanges["range"]["end"]
# contentChanges["text"] = data["text"]
# start["line"] = data["line"]
# start["character"] = 0
# end["line"] = data["line"]
# end["character"] = data["column"]
# GLib.idle_add( self.send_notification, method, params )
def _lsp_definition(self, data: dict):
method = "textDocument/definition"
params = definition_request["params"]
params["textDocument"]["uri"] = data["uri"]
params["textDocument"]["languageId"] = data["language_id"]
params["textDocument"]["version"] = data["version"]
params["position"]["line"] = data["line"]
params["position"]["character"] = data["column"]
GLib.idle_add( self.send_request, method, params )
def _lsp_completion(self, data: dict):
method = "textDocument/completion"
params = completion_request["params"]
params["textDocument"]["uri"] = data["uri"]
params["textDocument"]["languageId"] = data["language_id"]
params["textDocument"]["version"] = data["version"]
params["position"]["line"] = data["line"]
params["position"]["character"] = data["column"]
GLib.idle_add( self.send_request, method, params )

View File

@@ -0,0 +1,57 @@
# Python imports
import traceback
import subprocess
# Lib imports
from gi.repository import GLib
# Application imports
# from libs import websockets
from libs.dto.code.lsp.lsp_messages import LEN_HEADER, TYPE_HEADER, get_message_str, get_message_obj
from libs.dto.code.lsp.lsp_message_structs import \
LSPResponseTypes, ClientRequest, ClientNotification, LSPResponseRequest, LSPResponseNotification, LSPIDResponseNotification
from .lsp_controller_base import LSPControllerBase
from .websocket_client import WebsocketClient
class LSPControllerWebsocket(LSPControllerBase):
def _send_message(self, data: ClientRequest or ClientNotification):
if not data: return
message_str = get_message_str(data)
message_size = len(message_str)
message = f"Content-Length: {message_size}\r\n\r\n{message_str}"
logger.debug(f"Client: {message_str}")
self.ws_client.send(message_str)
def start_client(self):
self.ws_client = WebsocketClient()
self.ws_client.set_socket(self._socket)
self.ws_client.set_callback(self._monitor_lsp_response)
self.ws_client.start_client()
return self.ws_client
def stop_client(self):
if not hasattr(self, "ws_client"): return
self.ws_client.close_client()
def _monitor_lsp_response(self, data: None or {}):
if not data: return
message = get_message_obj(data)
keys = message.keys()
lsp_response = None
if "result" in keys:
lsp_response = LSPResponseRequest(**get_message_obj(data))
if "method" in keys:
lsp_response = LSPResponseNotification(**get_message_obj(data)) if not "id" in keys else LSPIDResponseNotification( **get_message_obj(data) )
if not lsp_response: return
GLib.idle_add(self.handle_lsp_response, lsp_response)

View File

@@ -0,0 +1,62 @@
# Python imports
import json
import threading
# Lib imports
# Application imports
from ..libs import websocket
class WebsocketClient:
def __init__(self):
self.ws = None
self._socket = None
self._connected = threading.Event()
def set_socket(self, socket: str):
self._socket = socket
def unset_socket(self):
self._socket = None
def send(self, message: str):
self.ws.send(message)
def on_message(self, ws, message: dict):
self.respond(message)
def on_error(self, ws, error: str):
logger.debug(f"WS Error: {error}")
def on_close(self, ws, close_status_code: int, close_msg: str):
logger.debug("WS Closed...")
def on_open(self, ws):
self._connected.set()
logger.debug("WS opened connection...")
def wait_for_connection(self, timeout: float = 5.0) -> bool:
return self._connected.wait(timeout)
def set_callback(self, callback: object):
self.respond = callback
def close_client(self):
self.ws.close()
@daemon_threaded
def start_client(self):
if not self._socket:
raise Exception("Socket address isn't set so cannot start WebsocketClient listener...")
# websocket.enableTrace(True)
self.ws = websocket.WebSocketApp(self._socket,
on_open = self.on_open,
on_message = self.on_message,
on_error = self.on_error,
on_close = self.on_close)
self.ws.run_forever(reconnect = 0.5)

View File

@@ -0,0 +1,3 @@
"""
Pligin Libs Module
"""

View File

@@ -0,0 +1,27 @@
"""
__init__.py
websocket - WebSocket client library for Python
Copyright 2024 engn33r
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from ._abnf import *
from ._app import WebSocketApp as WebSocketApp, setReconnect as setReconnect
from ._core import *
from ._exceptions import *
from ._logging import *
from ._socket import *
__version__ = "1.8.0"

View File

@@ -0,0 +1,453 @@
import array
import os
import struct
import sys
from threading import Lock
from typing import Callable, Optional, Union
from ._exceptions import WebSocketPayloadException, WebSocketProtocolException
from ._utils import validate_utf8
"""
_abnf.py
websocket - WebSocket client library for Python
Copyright 2024 engn33r
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
try:
# If wsaccel is available, use compiled routines to mask data.
# wsaccel only provides around a 10% speed boost compared
# to the websocket-client _mask() implementation.
# Note that wsaccel is unmaintained.
from wsaccel.xormask import XorMaskerSimple
def _mask(mask_value: array.array, data_value: array.array) -> bytes:
mask_result: bytes = XorMaskerSimple(mask_value).process(data_value)
return mask_result
except ImportError:
# wsaccel is not available, use websocket-client _mask()
native_byteorder = sys.byteorder
def _mask(mask_value: array.array, data_value: array.array) -> bytes:
datalen = len(data_value)
int_data_value = int.from_bytes(data_value, native_byteorder)
int_mask_value = int.from_bytes(
mask_value * (datalen // 4) + mask_value[: datalen % 4], native_byteorder
)
return (int_data_value ^ int_mask_value).to_bytes(datalen, native_byteorder)
__all__ = [
"ABNF",
"continuous_frame",
"frame_buffer",
"STATUS_NORMAL",
"STATUS_GOING_AWAY",
"STATUS_PROTOCOL_ERROR",
"STATUS_UNSUPPORTED_DATA_TYPE",
"STATUS_STATUS_NOT_AVAILABLE",
"STATUS_ABNORMAL_CLOSED",
"STATUS_INVALID_PAYLOAD",
"STATUS_POLICY_VIOLATION",
"STATUS_MESSAGE_TOO_BIG",
"STATUS_INVALID_EXTENSION",
"STATUS_UNEXPECTED_CONDITION",
"STATUS_BAD_GATEWAY",
"STATUS_TLS_HANDSHAKE_ERROR",
]
# closing frame status codes.
STATUS_NORMAL = 1000
STATUS_GOING_AWAY = 1001
STATUS_PROTOCOL_ERROR = 1002
STATUS_UNSUPPORTED_DATA_TYPE = 1003
STATUS_STATUS_NOT_AVAILABLE = 1005
STATUS_ABNORMAL_CLOSED = 1006
STATUS_INVALID_PAYLOAD = 1007
STATUS_POLICY_VIOLATION = 1008
STATUS_MESSAGE_TOO_BIG = 1009
STATUS_INVALID_EXTENSION = 1010
STATUS_UNEXPECTED_CONDITION = 1011
STATUS_SERVICE_RESTART = 1012
STATUS_TRY_AGAIN_LATER = 1013
STATUS_BAD_GATEWAY = 1014
STATUS_TLS_HANDSHAKE_ERROR = 1015
VALID_CLOSE_STATUS = (
STATUS_NORMAL,
STATUS_GOING_AWAY,
STATUS_PROTOCOL_ERROR,
STATUS_UNSUPPORTED_DATA_TYPE,
STATUS_INVALID_PAYLOAD,
STATUS_POLICY_VIOLATION,
STATUS_MESSAGE_TOO_BIG,
STATUS_INVALID_EXTENSION,
STATUS_UNEXPECTED_CONDITION,
STATUS_SERVICE_RESTART,
STATUS_TRY_AGAIN_LATER,
STATUS_BAD_GATEWAY,
)
class ABNF:
"""
ABNF frame class.
See http://tools.ietf.org/html/rfc5234
and http://tools.ietf.org/html/rfc6455#section-5.2
"""
# operation code values.
OPCODE_CONT = 0x0
OPCODE_TEXT = 0x1
OPCODE_BINARY = 0x2
OPCODE_CLOSE = 0x8
OPCODE_PING = 0x9
OPCODE_PONG = 0xA
# available operation code value tuple
OPCODES = (
OPCODE_CONT,
OPCODE_TEXT,
OPCODE_BINARY,
OPCODE_CLOSE,
OPCODE_PING,
OPCODE_PONG,
)
# opcode human readable string
OPCODE_MAP = {
OPCODE_CONT: "cont",
OPCODE_TEXT: "text",
OPCODE_BINARY: "binary",
OPCODE_CLOSE: "close",
OPCODE_PING: "ping",
OPCODE_PONG: "pong",
}
# data length threshold.
LENGTH_7 = 0x7E
LENGTH_16 = 1 << 16
LENGTH_63 = 1 << 63
def __init__(
self,
fin: int = 0,
rsv1: int = 0,
rsv2: int = 0,
rsv3: int = 0,
opcode: int = OPCODE_TEXT,
mask_value: int = 1,
data: Union[str, bytes, None] = "",
) -> None:
"""
Constructor for ABNF. Please check RFC for arguments.
"""
self.fin = fin
self.rsv1 = rsv1
self.rsv2 = rsv2
self.rsv3 = rsv3
self.opcode = opcode
self.mask_value = mask_value
if data is None:
data = ""
self.data = data
self.get_mask_key = os.urandom
def validate(self, skip_utf8_validation: bool = False) -> None:
"""
Validate the ABNF frame.
Parameters
----------
skip_utf8_validation: skip utf8 validation.
"""
if self.rsv1 or self.rsv2 or self.rsv3:
raise WebSocketProtocolException("rsv is not implemented, yet")
if self.opcode not in ABNF.OPCODES:
raise WebSocketProtocolException("Invalid opcode %r", self.opcode)
if self.opcode == ABNF.OPCODE_PING and not self.fin:
raise WebSocketProtocolException("Invalid ping frame.")
if self.opcode == ABNF.OPCODE_CLOSE:
l = len(self.data)
if not l:
return
if l == 1 or l >= 126:
raise WebSocketProtocolException("Invalid close frame.")
if l > 2 and not skip_utf8_validation and not validate_utf8(self.data[2:]):
raise WebSocketProtocolException("Invalid close frame.")
code = 256 * int(self.data[0]) + int(self.data[1])
if not self._is_valid_close_status(code):
raise WebSocketProtocolException("Invalid close opcode %r", code)
@staticmethod
def _is_valid_close_status(code: int) -> bool:
return code in VALID_CLOSE_STATUS or (3000 <= code < 5000)
def __str__(self) -> str:
return f"fin={self.fin} opcode={self.opcode} data={self.data}"
@staticmethod
def create_frame(data: Union[bytes, str], opcode: int, fin: int = 1) -> "ABNF":
"""
Create frame to send text, binary and other data.
Parameters
----------
data: str
data to send. This is string value(byte array).
If opcode is OPCODE_TEXT and this value is unicode,
data value is converted into unicode string, automatically.
opcode: int
operation code. please see OPCODE_MAP.
fin: int
fin flag. if set to 0, create continue fragmentation.
"""
if opcode == ABNF.OPCODE_TEXT and isinstance(data, str):
data = data.encode("utf-8")
# mask must be set if send data from client
return ABNF(fin, 0, 0, 0, opcode, 1, data)
def format(self) -> bytes:
"""
Format this object to string(byte array) to send data to server.
"""
if any(x not in (0, 1) for x in [self.fin, self.rsv1, self.rsv2, self.rsv3]):
raise ValueError("not 0 or 1")
if self.opcode not in ABNF.OPCODES:
raise ValueError("Invalid OPCODE")
length = len(self.data)
if length >= ABNF.LENGTH_63:
raise ValueError("data is too long")
frame_header = chr(
self.fin << 7
| self.rsv1 << 6
| self.rsv2 << 5
| self.rsv3 << 4
| self.opcode
).encode("latin-1")
if length < ABNF.LENGTH_7:
frame_header += chr(self.mask_value << 7 | length).encode("latin-1")
elif length < ABNF.LENGTH_16:
frame_header += chr(self.mask_value << 7 | 0x7E).encode("latin-1")
frame_header += struct.pack("!H", length)
else:
frame_header += chr(self.mask_value << 7 | 0x7F).encode("latin-1")
frame_header += struct.pack("!Q", length)
if not self.mask_value:
if isinstance(self.data, str):
self.data = self.data.encode("utf-8")
return frame_header + self.data
mask_key = self.get_mask_key(4)
return frame_header + self._get_masked(mask_key)
def _get_masked(self, mask_key: Union[str, bytes]) -> bytes:
s = ABNF.mask(mask_key, self.data)
if isinstance(mask_key, str):
mask_key = mask_key.encode("utf-8")
return mask_key + s
@staticmethod
def mask(mask_key: Union[str, bytes], data: Union[str, bytes]) -> bytes:
"""
Mask or unmask data. Just do xor for each byte
Parameters
----------
mask_key: bytes or str
4 byte mask.
data: bytes or str
data to mask/unmask.
"""
if data is None:
data = ""
if isinstance(mask_key, str):
mask_key = mask_key.encode("latin-1")
if isinstance(data, str):
data = data.encode("latin-1")
return _mask(array.array("B", mask_key), array.array("B", data))
class frame_buffer:
_HEADER_MASK_INDEX = 5
_HEADER_LENGTH_INDEX = 6
def __init__(
self, recv_fn: Callable[[int], int], skip_utf8_validation: bool
) -> None:
self.recv = recv_fn
self.skip_utf8_validation = skip_utf8_validation
# Buffers over the packets from the layer beneath until desired amount
# bytes of bytes are received.
self.recv_buffer: list = []
self.clear()
self.lock = Lock()
def clear(self) -> None:
self.header: Optional[tuple] = None
self.length: Optional[int] = None
self.mask_value: Union[bytes, str, None] = None
def has_received_header(self) -> bool:
return self.header is None
def recv_header(self) -> None:
header = self.recv_strict(2)
b1 = header[0]
fin = b1 >> 7 & 1
rsv1 = b1 >> 6 & 1
rsv2 = b1 >> 5 & 1
rsv3 = b1 >> 4 & 1
opcode = b1 & 0xF
b2 = header[1]
has_mask = b2 >> 7 & 1
length_bits = b2 & 0x7F
self.header = (fin, rsv1, rsv2, rsv3, opcode, has_mask, length_bits)
def has_mask(self) -> Union[bool, int]:
if not self.header:
return False
header_val: int = self.header[frame_buffer._HEADER_MASK_INDEX]
return header_val
def has_received_length(self) -> bool:
return self.length is None
def recv_length(self) -> None:
bits = self.header[frame_buffer._HEADER_LENGTH_INDEX]
length_bits = bits & 0x7F
if length_bits == 0x7E:
v = self.recv_strict(2)
self.length = struct.unpack("!H", v)[0]
elif length_bits == 0x7F:
v = self.recv_strict(8)
self.length = struct.unpack("!Q", v)[0]
else:
self.length = length_bits
def has_received_mask(self) -> bool:
return self.mask_value is None
def recv_mask(self) -> None:
self.mask_value = self.recv_strict(4) if self.has_mask() else ""
def recv_frame(self) -> ABNF:
with self.lock:
# Header
if self.has_received_header():
self.recv_header()
(fin, rsv1, rsv2, rsv3, opcode, has_mask, _) = self.header
# Frame length
if self.has_received_length():
self.recv_length()
length = self.length
# Mask
if self.has_received_mask():
self.recv_mask()
mask_value = self.mask_value
# Payload
payload = self.recv_strict(length)
if has_mask:
payload = ABNF.mask(mask_value, payload)
# Reset for next frame
self.clear()
frame = ABNF(fin, rsv1, rsv2, rsv3, opcode, has_mask, payload)
frame.validate(self.skip_utf8_validation)
return frame
def recv_strict(self, bufsize: int) -> bytes:
shortage = bufsize - sum(map(len, self.recv_buffer))
while shortage > 0:
# Limit buffer size that we pass to socket.recv() to avoid
# fragmenting the heap -- the number of bytes recv() actually
# reads is limited by socket buffer and is relatively small,
# yet passing large numbers repeatedly causes lots of large
# buffers allocated and then shrunk, which results in
# fragmentation.
bytes_ = self.recv(min(16384, shortage))
self.recv_buffer.append(bytes_)
shortage -= len(bytes_)
unified = b"".join(self.recv_buffer)
if shortage == 0:
self.recv_buffer = []
return unified
else:
self.recv_buffer = [unified[bufsize:]]
return unified[:bufsize]
class continuous_frame:
def __init__(self, fire_cont_frame: bool, skip_utf8_validation: bool) -> None:
self.fire_cont_frame = fire_cont_frame
self.skip_utf8_validation = skip_utf8_validation
self.cont_data: Optional[list] = None
self.recving_frames: Optional[int] = None
def validate(self, frame: ABNF) -> None:
if not self.recving_frames and frame.opcode == ABNF.OPCODE_CONT:
raise WebSocketProtocolException("Illegal frame")
if self.recving_frames and frame.opcode in (
ABNF.OPCODE_TEXT,
ABNF.OPCODE_BINARY,
):
raise WebSocketProtocolException("Illegal frame")
def add(self, frame: ABNF) -> None:
if self.cont_data:
self.cont_data[1] += frame.data
else:
if frame.opcode in (ABNF.OPCODE_TEXT, ABNF.OPCODE_BINARY):
self.recving_frames = frame.opcode
self.cont_data = [frame.opcode, frame.data]
if frame.fin:
self.recving_frames = None
def is_fire(self, frame: ABNF) -> Union[bool, int]:
return frame.fin or self.fire_cont_frame
def extract(self, frame: ABNF) -> tuple:
data = self.cont_data
self.cont_data = None
frame.data = data[1]
if (
not self.fire_cont_frame
and data[0] == ABNF.OPCODE_TEXT
and not self.skip_utf8_validation
and not validate_utf8(frame.data)
):
raise WebSocketPayloadException(f"cannot decode: {repr(frame.data)}")
return data[0], frame

View File

@@ -0,0 +1,677 @@
import inspect
import selectors
import socket
import threading
import time
from typing import Any, Callable, Optional, Union
from . import _logging
from ._abnf import ABNF
from ._core import WebSocket, getdefaulttimeout
from ._exceptions import (
WebSocketConnectionClosedException,
WebSocketException,
WebSocketTimeoutException,
)
from ._ssl_compat import SSLEOFError
from ._url import parse_url
"""
_app.py
websocket - WebSocket client library for Python
Copyright 2024 engn33r
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
__all__ = ["WebSocketApp"]
RECONNECT = 0
def setReconnect(reconnectInterval: int) -> None:
global RECONNECT
RECONNECT = reconnectInterval
class DispatcherBase:
"""
DispatcherBase
"""
def __init__(self, app: Any, ping_timeout: Union[float, int, None]) -> None:
self.app = app
self.ping_timeout = ping_timeout
def timeout(self, seconds: Union[float, int, None], callback: Callable) -> None:
time.sleep(seconds)
callback()
def reconnect(self, seconds: int, reconnector: Callable) -> None:
try:
_logging.info(
f"reconnect() - retrying in {seconds} seconds [{len(inspect.stack())} frames in stack]"
)
time.sleep(seconds)
reconnector(reconnecting=True)
except KeyboardInterrupt as e:
_logging.info(f"User exited {e}")
raise e
class Dispatcher(DispatcherBase):
"""
Dispatcher
"""
def read(
self,
sock: socket.socket,
read_callback: Callable,
check_callback: Callable,
) -> None:
sel = selectors.DefaultSelector()
sel.register(self.app.sock.sock, selectors.EVENT_READ)
try:
while self.app.keep_running:
if sel.select(self.ping_timeout):
if not read_callback():
break
check_callback()
finally:
sel.close()
class SSLDispatcher(DispatcherBase):
"""
SSLDispatcher
"""
def read(
self,
sock: socket.socket,
read_callback: Callable,
check_callback: Callable,
) -> None:
sock = self.app.sock.sock
sel = selectors.DefaultSelector()
sel.register(sock, selectors.EVENT_READ)
try:
while self.app.keep_running:
if self.select(sock, sel):
if not read_callback():
break
check_callback()
finally:
sel.close()
def select(self, sock, sel: selectors.DefaultSelector):
sock = self.app.sock.sock
if sock.pending():
return [
sock,
]
r = sel.select(self.ping_timeout)
if len(r) > 0:
return r[0][0]
class WrappedDispatcher:
"""
WrappedDispatcher
"""
def __init__(self, app, ping_timeout: Union[float, int, None], dispatcher) -> None:
self.app = app
self.ping_timeout = ping_timeout
self.dispatcher = dispatcher
dispatcher.signal(2, dispatcher.abort) # keyboard interrupt
def read(
self,
sock: socket.socket,
read_callback: Callable,
check_callback: Callable,
) -> None:
self.dispatcher.read(sock, read_callback)
self.ping_timeout and self.timeout(self.ping_timeout, check_callback)
def timeout(self, seconds: float, callback: Callable) -> None:
self.dispatcher.timeout(seconds, callback)
def reconnect(self, seconds: int, reconnector: Callable) -> None:
self.timeout(seconds, reconnector)
class WebSocketApp:
"""
Higher level of APIs are provided. The interface is like JavaScript WebSocket object.
"""
def __init__(
self,
url: str,
header: Union[list, dict, Callable, None] = None,
on_open: Optional[Callable[[WebSocket], None]] = None,
on_reconnect: Optional[Callable[[WebSocket], None]] = None,
on_message: Optional[Callable[[WebSocket, Any], None]] = None,
on_error: Optional[Callable[[WebSocket, Any], None]] = None,
on_close: Optional[Callable[[WebSocket, Any, Any], None]] = None,
on_ping: Optional[Callable] = None,
on_pong: Optional[Callable] = None,
on_cont_message: Optional[Callable] = None,
keep_running: bool = True,
get_mask_key: Optional[Callable] = None,
cookie: Optional[str] = None,
subprotocols: Optional[list] = None,
on_data: Optional[Callable] = None,
socket: Optional[socket.socket] = None,
) -> None:
"""
WebSocketApp initialization
Parameters
----------
url: str
Websocket url.
header: list or dict or Callable
Custom header for websocket handshake.
If the parameter is a callable object, it is called just before the connection attempt.
The returned dict or list is used as custom header value.
This could be useful in order to properly setup timestamp dependent headers.
on_open: function
Callback object which is called at opening websocket.
on_open has one argument.
The 1st argument is this class object.
on_reconnect: function
Callback object which is called at reconnecting websocket.
on_reconnect has one argument.
The 1st argument is this class object.
on_message: function
Callback object which is called when received data.
on_message has 2 arguments.
The 1st argument is this class object.
The 2nd argument is utf-8 data received from the server.
on_error: function
Callback object which is called when we get error.
on_error has 2 arguments.
The 1st argument is this class object.
The 2nd argument is exception object.
on_close: function
Callback object which is called when connection is closed.
on_close has 3 arguments.
The 1st argument is this class object.
The 2nd argument is close_status_code.
The 3rd argument is close_msg.
on_cont_message: function
Callback object which is called when a continuation
frame is received.
on_cont_message has 3 arguments.
The 1st argument is this class object.
The 2nd argument is utf-8 string which we get from the server.
The 3rd argument is continue flag. if 0, the data continue
to next frame data
on_data: function
Callback object which is called when a message received.
This is called before on_message or on_cont_message,
and then on_message or on_cont_message is called.
on_data has 4 argument.
The 1st argument is this class object.
The 2nd argument is utf-8 string which we get from the server.
The 3rd argument is data type. ABNF.OPCODE_TEXT or ABNF.OPCODE_BINARY will be came.
The 4th argument is continue flag. If 0, the data continue
keep_running: bool
This parameter is obsolete and ignored.
get_mask_key: function
A callable function to get new mask keys, see the
WebSocket.set_mask_key's docstring for more information.
cookie: str
Cookie value.
subprotocols: list
List of available sub protocols. Default is None.
socket: socket
Pre-initialized stream socket.
"""
self.url = url
self.header = header if header is not None else []
self.cookie = cookie
self.on_open = on_open
self.on_reconnect = on_reconnect
self.on_message = on_message
self.on_data = on_data
self.on_error = on_error
self.on_close = on_close
self.on_ping = on_ping
self.on_pong = on_pong
self.on_cont_message = on_cont_message
self.keep_running = False
self.get_mask_key = get_mask_key
self.sock: Optional[WebSocket] = None
self.last_ping_tm = float(0)
self.last_pong_tm = float(0)
self.ping_thread: Optional[threading.Thread] = None
self.stop_ping: Optional[threading.Event] = None
self.ping_interval = float(0)
self.ping_timeout: Union[float, int, None] = None
self.ping_payload = ""
self.subprotocols = subprotocols
self.prepared_socket = socket
self.has_errored = False
self.has_done_teardown = False
self.has_done_teardown_lock = threading.Lock()
def send(self, data: Union[bytes, str], opcode: int = ABNF.OPCODE_TEXT) -> None:
"""
send message
Parameters
----------
data: str
Message to send. If you set opcode to OPCODE_TEXT,
data must be utf-8 string or unicode.
opcode: int
Operation code of data. Default is OPCODE_TEXT.
"""
if not self.sock or self.sock.send(data, opcode) == 0:
raise WebSocketConnectionClosedException("Connection is already closed.")
def send_text(self, text_data: str) -> None:
"""
Sends UTF-8 encoded text.
"""
if not self.sock or self.sock.send(text_data, ABNF.OPCODE_TEXT) == 0:
raise WebSocketConnectionClosedException("Connection is already closed.")
def send_bytes(self, data: Union[bytes, bytearray]) -> None:
"""
Sends a sequence of bytes.
"""
if not self.sock or self.sock.send(data, ABNF.OPCODE_BINARY) == 0:
raise WebSocketConnectionClosedException("Connection is already closed.")
def close(self, **kwargs) -> None:
"""
Close websocket connection.
"""
self.keep_running = False
if self.sock:
self.sock.close(**kwargs)
self.sock = None
def _start_ping_thread(self) -> None:
self.last_ping_tm = self.last_pong_tm = float(0)
self.stop_ping = threading.Event()
self.ping_thread = threading.Thread(target=self._send_ping)
self.ping_thread.daemon = True
self.ping_thread.start()
def _stop_ping_thread(self) -> None:
if self.stop_ping:
self.stop_ping.set()
if self.ping_thread and self.ping_thread.is_alive():
self.ping_thread.join(3)
self.last_ping_tm = self.last_pong_tm = float(0)
def _send_ping(self) -> None:
if self.stop_ping.wait(self.ping_interval) or self.keep_running is False:
return
while not self.stop_ping.wait(self.ping_interval) and self.keep_running is True:
if self.sock:
self.last_ping_tm = time.time()
try:
_logging.debug("Sending ping")
self.sock.ping(self.ping_payload)
except Exception as e:
_logging.debug(f"Failed to send ping: {e}")
def run_forever(
self,
sockopt: tuple = None,
sslopt: dict = None,
ping_interval: Union[float, int] = 0,
ping_timeout: Union[float, int, None] = None,
ping_payload: str = "",
http_proxy_host: str = None,
http_proxy_port: Union[int, str] = None,
http_no_proxy: list = None,
http_proxy_auth: tuple = None,
http_proxy_timeout: Optional[float] = None,
skip_utf8_validation: bool = False,
host: str = None,
origin: str = None,
dispatcher=None,
suppress_origin: bool = False,
proxy_type: str = None,
reconnect: int = None,
) -> bool:
"""
Run event loop for WebSocket framework.
This loop is an infinite loop and is alive while websocket is available.
Parameters
----------
sockopt: tuple
Values for socket.setsockopt.
sockopt must be tuple
and each element is argument of sock.setsockopt.
sslopt: dict
Optional dict object for ssl socket option.
ping_interval: int or float
Automatically send "ping" command
every specified period (in seconds).
If set to 0, no ping is sent periodically.
ping_timeout: int or float
Timeout (in seconds) if the pong message is not received.
ping_payload: str
Payload message to send with each ping.
http_proxy_host: str
HTTP proxy host name.
http_proxy_port: int or str
HTTP proxy port. If not set, set to 80.
http_no_proxy: list
Whitelisted host names that don't use the proxy.
http_proxy_timeout: int or float
HTTP proxy timeout, default is 60 sec as per python-socks.
http_proxy_auth: tuple
HTTP proxy auth information. tuple of username and password. Default is None.
skip_utf8_validation: bool
skip utf8 validation.
host: str
update host header.
origin: str
update origin header.
dispatcher: Dispatcher object
customize reading data from socket.
suppress_origin: bool
suppress outputting origin header.
proxy_type: str
type of proxy from: http, socks4, socks4a, socks5, socks5h
reconnect: int
delay interval when reconnecting
Returns
-------
teardown: bool
False if the `WebSocketApp` is closed or caught KeyboardInterrupt,
True if any other exception was raised during a loop.
"""
if reconnect is None:
reconnect = RECONNECT
if ping_timeout is not None and ping_timeout <= 0:
raise WebSocketException("Ensure ping_timeout > 0")
if ping_interval is not None and ping_interval < 0:
raise WebSocketException("Ensure ping_interval >= 0")
if ping_timeout and ping_interval and ping_interval <= ping_timeout:
raise WebSocketException("Ensure ping_interval > ping_timeout")
if not sockopt:
sockopt = ()
if not sslopt:
sslopt = {}
if self.sock:
raise WebSocketException("socket is already opened")
self.ping_interval = ping_interval
self.ping_timeout = ping_timeout
self.ping_payload = ping_payload
self.has_done_teardown = False
self.keep_running = True
def teardown(close_frame: ABNF = None):
"""
Tears down the connection.
Parameters
----------
close_frame: ABNF frame
If close_frame is set, the on_close handler is invoked
with the statusCode and reason from the provided frame.
"""
# teardown() is called in many code paths to ensure resources are cleaned up and on_close is fired.
# To ensure the work is only done once, we use this bool and lock.
with self.has_done_teardown_lock:
if self.has_done_teardown:
return
self.has_done_teardown = True
self._stop_ping_thread()
self.keep_running = False
if self.sock:
self.sock.close()
close_status_code, close_reason = self._get_close_args(
close_frame if close_frame else None
)
self.sock = None
# Finally call the callback AFTER all teardown is complete
self._callback(self.on_close, close_status_code, close_reason)
def setSock(reconnecting: bool = False) -> None:
if reconnecting and self.sock:
self.sock.shutdown()
self.sock = WebSocket(
self.get_mask_key,
sockopt=sockopt,
sslopt=sslopt,
fire_cont_frame=self.on_cont_message is not None,
skip_utf8_validation=skip_utf8_validation,
enable_multithread=True,
)
self.sock.settimeout(getdefaulttimeout())
try:
header = self.header() if callable(self.header) else self.header
self.sock.connect(
self.url,
header=header,
cookie=self.cookie,
http_proxy_host=http_proxy_host,
http_proxy_port=http_proxy_port,
http_no_proxy=http_no_proxy,
http_proxy_auth=http_proxy_auth,
http_proxy_timeout=http_proxy_timeout,
subprotocols=self.subprotocols,
host=host,
origin=origin,
suppress_origin=suppress_origin,
proxy_type=proxy_type,
socket=self.prepared_socket,
)
_logging.info("Websocket connected")
if self.ping_interval:
self._start_ping_thread()
if reconnecting and self.on_reconnect:
self._callback(self.on_reconnect)
else:
self._callback(self.on_open)
dispatcher.read(self.sock.sock, read, check)
except (
WebSocketConnectionClosedException,
ConnectionRefusedError,
KeyboardInterrupt,
SystemExit,
Exception,
) as e:
handleDisconnect(e, reconnecting)
def read() -> bool:
if not self.keep_running:
return teardown()
try:
op_code, frame = self.sock.recv_data_frame(True)
except (
WebSocketConnectionClosedException,
KeyboardInterrupt,
SSLEOFError,
) as e:
if custom_dispatcher:
return handleDisconnect(e, bool(reconnect))
else:
raise e
if op_code == ABNF.OPCODE_CLOSE:
return teardown(frame)
elif op_code == ABNF.OPCODE_PING:
self._callback(self.on_ping, frame.data)
elif op_code == ABNF.OPCODE_PONG:
self.last_pong_tm = time.time()
self._callback(self.on_pong, frame.data)
elif op_code == ABNF.OPCODE_CONT and self.on_cont_message:
self._callback(self.on_data, frame.data, frame.opcode, frame.fin)
self._callback(self.on_cont_message, frame.data, frame.fin)
else:
data = frame.data
if op_code == ABNF.OPCODE_TEXT and not skip_utf8_validation:
data = data.decode("utf-8")
self._callback(self.on_data, data, frame.opcode, True)
self._callback(self.on_message, data)
return True
def check() -> bool:
if self.ping_timeout:
has_timeout_expired = (
time.time() - self.last_ping_tm > self.ping_timeout
)
has_pong_not_arrived_after_last_ping = (
self.last_pong_tm - self.last_ping_tm < 0
)
has_pong_arrived_too_late = (
self.last_pong_tm - self.last_ping_tm > self.ping_timeout
)
if (
self.last_ping_tm
and has_timeout_expired
and (
has_pong_not_arrived_after_last_ping
or has_pong_arrived_too_late
)
):
raise WebSocketTimeoutException("ping/pong timed out")
return True
def handleDisconnect(
e: Union[
WebSocketConnectionClosedException,
ConnectionRefusedError,
KeyboardInterrupt,
SystemExit,
Exception,
],
reconnecting: bool = False,
) -> bool:
self.has_errored = True
self._stop_ping_thread()
if not reconnecting:
self._callback(self.on_error, e)
if isinstance(e, (KeyboardInterrupt, SystemExit)):
teardown()
# Propagate further
raise
if reconnect:
_logging.info(f"{e} - reconnect")
if custom_dispatcher:
_logging.debug(
f"Calling custom dispatcher reconnect [{len(inspect.stack())} frames in stack]"
)
dispatcher.reconnect(reconnect, setSock)
else:
_logging.error(f"{e} - goodbye")
teardown()
custom_dispatcher = bool(dispatcher)
dispatcher = self.create_dispatcher(
ping_timeout, dispatcher, parse_url(self.url)[3]
)
try:
setSock()
if not custom_dispatcher and reconnect:
while self.keep_running:
_logging.debug(
f"Calling dispatcher reconnect [{len(inspect.stack())} frames in stack]"
)
dispatcher.reconnect(reconnect, setSock)
except (KeyboardInterrupt, Exception) as e:
_logging.info(f"tearing down on exception {e}")
teardown()
finally:
if not custom_dispatcher:
# Ensure teardown was called before returning from run_forever
teardown()
return self.has_errored
def create_dispatcher(
self,
ping_timeout: Union[float, int, None],
dispatcher: Optional[DispatcherBase] = None,
is_ssl: bool = False,
) -> Union[Dispatcher, SSLDispatcher, WrappedDispatcher]:
if dispatcher: # If custom dispatcher is set, use WrappedDispatcher
return WrappedDispatcher(self, ping_timeout, dispatcher)
timeout = ping_timeout or 10
if is_ssl:
return SSLDispatcher(self, timeout)
return Dispatcher(self, timeout)
def _get_close_args(self, close_frame: ABNF) -> list:
"""
_get_close_args extracts the close code and reason from the close body
if it exists (RFC6455 says WebSocket Connection Close Code is optional)
"""
# Need to catch the case where close_frame is None
# Otherwise the following if statement causes an error
if not self.on_close or not close_frame:
return [None, None]
# Extract close frame status code
if close_frame.data and len(close_frame.data) >= 2:
close_status_code = 256 * int(close_frame.data[0]) + int(
close_frame.data[1]
)
reason = close_frame.data[2:]
if isinstance(reason, bytes):
reason = reason.decode("utf-8")
return [close_status_code, reason]
else:
# Most likely reached this because len(close_frame_data.data) < 2
return [None, None]
def _callback(self, callback, *args) -> None:
if callback:
try:
callback(self, *args)
except Exception as e:
_logging.error(f"error from callback {callback}: {e}")
if self.on_error:
self.on_error(self, e)

View File

@@ -0,0 +1,75 @@
import http.cookies
from typing import Optional
"""
_cookiejar.py
websocket - WebSocket client library for Python
Copyright 2024 engn33r
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
class SimpleCookieJar:
def __init__(self) -> None:
self.jar: dict = {}
def add(self, set_cookie: Optional[str]) -> None:
if set_cookie:
simple_cookie = http.cookies.SimpleCookie(set_cookie)
for v in simple_cookie.values():
if domain := v.get("domain"):
if not domain.startswith("."):
domain = f".{domain}"
cookie = (
self.jar.get(domain)
if self.jar.get(domain)
else http.cookies.SimpleCookie()
)
cookie.update(simple_cookie)
self.jar[domain.lower()] = cookie
def set(self, set_cookie: str) -> None:
if set_cookie:
simple_cookie = http.cookies.SimpleCookie(set_cookie)
for v in simple_cookie.values():
if domain := v.get("domain"):
if not domain.startswith("."):
domain = f".{domain}"
self.jar[domain.lower()] = simple_cookie
def get(self, host: str) -> str:
if not host:
return ""
cookies = []
for domain, _ in self.jar.items():
host = host.lower()
if host.endswith(domain) or host == domain[1:]:
cookies.append(self.jar.get(domain))
return "; ".join(
filter(
None,
sorted(
[
f"{k}={v.value}"
for cookie in filter(None, cookies)
for k, v in cookie.items()
]
),
)
)

View File

@@ -0,0 +1,647 @@
import socket
import struct
import threading
import time
from typing import Optional, Union
# websocket modules
from ._abnf import ABNF, STATUS_NORMAL, continuous_frame, frame_buffer
from ._exceptions import WebSocketProtocolException, WebSocketConnectionClosedException
from ._handshake import SUPPORTED_REDIRECT_STATUSES, handshake
from ._http import connect, proxy_info
from ._logging import debug, error, trace, isEnabledForError, isEnabledForTrace
from ._socket import getdefaulttimeout, recv, send, sock_opt
from ._ssl_compat import ssl
from ._utils import NoLock
"""
_core.py
websocket - WebSocket client library for Python
Copyright 2024 engn33r
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
__all__ = ["WebSocket", "create_connection"]
class WebSocket:
"""
Low level WebSocket interface.
This class is based on the WebSocket protocol `draft-hixie-thewebsocketprotocol-76 <http://tools.ietf.org/html/draft-hixie-thewebsocketprotocol-76>`_
We can connect to the websocket server and send/receive data.
The following example is an echo client.
>>> import websocket
>>> ws = websocket.WebSocket()
>>> ws.connect("ws://echo.websocket.events")
>>> ws.recv()
'echo.websocket.events sponsored by Lob.com'
>>> ws.send("Hello, Server")
19
>>> ws.recv()
'Hello, Server'
>>> ws.close()
Parameters
----------
get_mask_key: func
A callable function to get new mask keys, see the
WebSocket.set_mask_key's docstring for more information.
sockopt: tuple
Values for socket.setsockopt.
sockopt must be tuple and each element is argument of sock.setsockopt.
sslopt: dict
Optional dict object for ssl socket options. See FAQ for details.
fire_cont_frame: bool
Fire recv event for each cont frame. Default is False.
enable_multithread: bool
If set to True, lock send method.
skip_utf8_validation: bool
Skip utf8 validation.
"""
def __init__(
self,
get_mask_key=None,
sockopt=None,
sslopt=None,
fire_cont_frame: bool = False,
enable_multithread: bool = True,
skip_utf8_validation: bool = False,
**_,
):
"""
Initialize WebSocket object.
Parameters
----------
sslopt: dict
Optional dict object for ssl socket options. See FAQ for details.
"""
self.sock_opt = sock_opt(sockopt, sslopt)
self.handshake_response = None
self.sock: Optional[socket.socket] = None
self.connected = False
self.get_mask_key = get_mask_key
# These buffer over the build-up of a single frame.
self.frame_buffer = frame_buffer(self._recv, skip_utf8_validation)
self.cont_frame = continuous_frame(fire_cont_frame, skip_utf8_validation)
if enable_multithread:
self.lock = threading.Lock()
self.readlock = threading.Lock()
else:
self.lock = NoLock()
self.readlock = NoLock()
def __iter__(self):
"""
Allow iteration over websocket, implying sequential `recv` executions.
"""
while True:
yield self.recv()
def __next__(self):
return self.recv()
def next(self):
return self.__next__()
def fileno(self):
return self.sock.fileno()
def set_mask_key(self, func):
"""
Set function to create mask key. You can customize mask key generator.
Mainly, this is for testing purpose.
Parameters
----------
func: func
callable object. the func takes 1 argument as integer.
The argument means length of mask key.
This func must return string(byte array),
which length is argument specified.
"""
self.get_mask_key = func
def gettimeout(self) -> Union[float, int, None]:
"""
Get the websocket timeout (in seconds) as an int or float
Returns
----------
timeout: int or float
returns timeout value (in seconds). This value could be either float/integer.
"""
return self.sock_opt.timeout
def settimeout(self, timeout: Union[float, int, None]):
"""
Set the timeout to the websocket.
Parameters
----------
timeout: int or float
timeout time (in seconds). This value could be either float/integer.
"""
self.sock_opt.timeout = timeout
if self.sock:
self.sock.settimeout(timeout)
timeout = property(gettimeout, settimeout)
def getsubprotocol(self):
"""
Get subprotocol
"""
if self.handshake_response:
return self.handshake_response.subprotocol
else:
return None
subprotocol = property(getsubprotocol)
def getstatus(self):
"""
Get handshake status
"""
if self.handshake_response:
return self.handshake_response.status
else:
return None
status = property(getstatus)
def getheaders(self):
"""
Get handshake response header
"""
if self.handshake_response:
return self.handshake_response.headers
else:
return None
def is_ssl(self):
try:
return isinstance(self.sock, ssl.SSLSocket)
except:
return False
headers = property(getheaders)
def connect(self, url, **options):
"""
Connect to url. url is websocket url scheme.
ie. ws://host:port/resource
You can customize using 'options'.
If you set "header" list object, you can set your own custom header.
>>> ws = WebSocket()
>>> ws.connect("ws://echo.websocket.events",
... header=["User-Agent: MyProgram",
... "x-custom: header"])
Parameters
----------
header: list or dict
Custom http header list or dict.
cookie: str
Cookie value.
origin: str
Custom origin url.
connection: str
Custom connection header value.
Default value "Upgrade" set in _handshake.py
suppress_origin: bool
Suppress outputting origin header.
host: str
Custom host header string.
timeout: int or float
Socket timeout time. This value is an integer or float.
If you set None for this value, it means "use default_timeout value"
http_proxy_host: str
HTTP proxy host name.
http_proxy_port: str or int
HTTP proxy port. Default is 80.
http_no_proxy: list
Whitelisted host names that don't use the proxy.
http_proxy_auth: tuple
HTTP proxy auth information. Tuple of username and password. Default is None.
http_proxy_timeout: int or float
HTTP proxy timeout, default is 60 sec as per python-socks.
redirect_limit: int
Number of redirects to follow.
subprotocols: list
List of available subprotocols. Default is None.
socket: socket
Pre-initialized stream socket.
"""
self.sock_opt.timeout = options.get("timeout", self.sock_opt.timeout)
self.sock, addrs = connect(
url, self.sock_opt, proxy_info(**options), options.pop("socket", None)
)
try:
self.handshake_response = handshake(self.sock, url, *addrs, **options)
for _ in range(options.pop("redirect_limit", 3)):
if self.handshake_response.status in SUPPORTED_REDIRECT_STATUSES:
url = self.handshake_response.headers["location"]
self.sock.close()
self.sock, addrs = connect(
url,
self.sock_opt,
proxy_info(**options),
options.pop("socket", None),
)
self.handshake_response = handshake(
self.sock, url, *addrs, **options
)
self.connected = True
except:
if self.sock:
self.sock.close()
self.sock = None
raise
def send(self, payload: Union[bytes, str], opcode: int = ABNF.OPCODE_TEXT) -> int:
"""
Send the data as string.
Parameters
----------
payload: str
Payload must be utf-8 string or unicode,
If the opcode is OPCODE_TEXT.
Otherwise, it must be string(byte array).
opcode: int
Operation code (opcode) to send.
"""
frame = ABNF.create_frame(payload, opcode)
return self.send_frame(frame)
def send_text(self, text_data: str) -> int:
"""
Sends UTF-8 encoded text.
"""
return self.send(text_data, ABNF.OPCODE_TEXT)
def send_bytes(self, data: Union[bytes, bytearray]) -> int:
"""
Sends a sequence of bytes.
"""
return self.send(data, ABNF.OPCODE_BINARY)
def send_frame(self, frame) -> int:
"""
Send the data frame.
>>> ws = create_connection("ws://echo.websocket.events")
>>> frame = ABNF.create_frame("Hello", ABNF.OPCODE_TEXT)
>>> ws.send_frame(frame)
>>> cont_frame = ABNF.create_frame("My name is ", ABNF.OPCODE_CONT, 0)
>>> ws.send_frame(frame)
>>> cont_frame = ABNF.create_frame("Foo Bar", ABNF.OPCODE_CONT, 1)
>>> ws.send_frame(frame)
Parameters
----------
frame: ABNF frame
frame data created by ABNF.create_frame
"""
if self.get_mask_key:
frame.get_mask_key = self.get_mask_key
data = frame.format()
length = len(data)
if isEnabledForTrace():
trace(f"++Sent raw: {repr(data)}")
trace(f"++Sent decoded: {frame.__str__()}")
with self.lock:
while data:
l = self._send(data)
data = data[l:]
return length
def send_binary(self, payload: bytes) -> int:
"""
Send a binary message (OPCODE_BINARY).
Parameters
----------
payload: bytes
payload of message to send.
"""
return self.send(payload, ABNF.OPCODE_BINARY)
def ping(self, payload: Union[str, bytes] = ""):
"""
Send ping data.
Parameters
----------
payload: str
data payload to send server.
"""
if isinstance(payload, str):
payload = payload.encode("utf-8")
self.send(payload, ABNF.OPCODE_PING)
def pong(self, payload: Union[str, bytes] = ""):
"""
Send pong data.
Parameters
----------
payload: str
data payload to send server.
"""
if isinstance(payload, str):
payload = payload.encode("utf-8")
self.send(payload, ABNF.OPCODE_PONG)
def recv(self) -> Union[str, bytes]:
"""
Receive string data(byte array) from the server.
Returns
----------
data: string (byte array) value.
"""
with self.readlock:
opcode, data = self.recv_data()
if opcode == ABNF.OPCODE_TEXT:
data_received: Union[bytes, str] = data
if isinstance(data_received, bytes):
return data_received.decode("utf-8")
elif isinstance(data_received, str):
return data_received
elif opcode == ABNF.OPCODE_BINARY:
data_binary: bytes = data
return data_binary
else:
return ""
def recv_data(self, control_frame: bool = False) -> tuple:
"""
Receive data with operation code.
Parameters
----------
control_frame: bool
a boolean flag indicating whether to return control frame
data, defaults to False
Returns
-------
opcode, frame.data: tuple
tuple of operation code and string(byte array) value.
"""
opcode, frame = self.recv_data_frame(control_frame)
return opcode, frame.data
def recv_data_frame(self, control_frame: bool = False) -> tuple:
"""
Receive data with operation code.
If a valid ping message is received, a pong response is sent.
Parameters
----------
control_frame: bool
a boolean flag indicating whether to return control frame
data, defaults to False
Returns
-------
frame.opcode, frame: tuple
tuple of operation code and string(byte array) value.
"""
while True:
frame = self.recv_frame()
if isEnabledForTrace():
trace(f"++Rcv raw: {repr(frame.format())}")
trace(f"++Rcv decoded: {frame.__str__()}")
if not frame:
# handle error:
# 'NoneType' object has no attribute 'opcode'
raise WebSocketProtocolException(f"Not a valid frame {frame}")
elif frame.opcode in (
ABNF.OPCODE_TEXT,
ABNF.OPCODE_BINARY,
ABNF.OPCODE_CONT,
):
self.cont_frame.validate(frame)
self.cont_frame.add(frame)
if self.cont_frame.is_fire(frame):
return self.cont_frame.extract(frame)
elif frame.opcode == ABNF.OPCODE_CLOSE:
self.send_close()
return frame.opcode, frame
elif frame.opcode == ABNF.OPCODE_PING:
if len(frame.data) < 126:
self.pong(frame.data)
else:
raise WebSocketProtocolException("Ping message is too long")
if control_frame:
return frame.opcode, frame
elif frame.opcode == ABNF.OPCODE_PONG:
if control_frame:
return frame.opcode, frame
def recv_frame(self):
"""
Receive data as frame from server.
Returns
-------
self.frame_buffer.recv_frame(): ABNF frame object
"""
return self.frame_buffer.recv_frame()
def send_close(self, status: int = STATUS_NORMAL, reason: bytes = b""):
"""
Send close data to the server.
Parameters
----------
status: int
Status code to send. See STATUS_XXX.
reason: str or bytes
The reason to close. This must be string or UTF-8 bytes.
"""
if status < 0 or status >= ABNF.LENGTH_16:
raise ValueError("code is invalid range")
self.connected = False
self.send(struct.pack("!H", status) + reason, ABNF.OPCODE_CLOSE)
def close(self, status: int = STATUS_NORMAL, reason: bytes = b"", timeout: int = 3):
"""
Close Websocket object
Parameters
----------
status: int
Status code to send. See VALID_CLOSE_STATUS in ABNF.
reason: bytes
The reason to close in UTF-8.
timeout: int or float
Timeout until receive a close frame.
If None, it will wait forever until receive a close frame.
"""
if not self.connected:
return
if status < 0 or status >= ABNF.LENGTH_16:
raise ValueError("code is invalid range")
try:
self.connected = False
self.send(struct.pack("!H", status) + reason, ABNF.OPCODE_CLOSE)
sock_timeout = self.sock.gettimeout()
self.sock.settimeout(timeout)
start_time = time.time()
while timeout is None or time.time() - start_time < timeout:
try:
frame = self.recv_frame()
if frame.opcode != ABNF.OPCODE_CLOSE:
continue
if isEnabledForError():
recv_status = struct.unpack("!H", frame.data[0:2])[0]
if recv_status >= 3000 and recv_status <= 4999:
debug(f"close status: {repr(recv_status)}")
elif recv_status != STATUS_NORMAL:
error(f"close status: {repr(recv_status)}")
break
except:
break
self.sock.settimeout(sock_timeout)
self.sock.shutdown(socket.SHUT_RDWR)
except:
pass
self.shutdown()
def abort(self):
"""
Low-level asynchronous abort, wakes up other threads that are waiting in recv_*
"""
if self.connected:
self.sock.shutdown(socket.SHUT_RDWR)
def shutdown(self):
"""
close socket, immediately.
"""
if self.sock:
self.sock.close()
self.sock = None
self.connected = False
def _send(self, data: Union[str, bytes]):
return send(self.sock, data)
def _recv(self, bufsize):
try:
return recv(self.sock, bufsize)
except WebSocketConnectionClosedException:
if self.sock:
self.sock.close()
self.sock = None
self.connected = False
raise
def create_connection(url: str, timeout=None, class_=WebSocket, **options):
"""
Connect to url and return websocket object.
Connect to url and return the WebSocket object.
Passing optional timeout parameter will set the timeout on the socket.
If no timeout is supplied,
the global default timeout setting returned by getdefaulttimeout() is used.
You can customize using 'options'.
If you set "header" list object, you can set your own custom header.
>>> conn = create_connection("ws://echo.websocket.events",
... header=["User-Agent: MyProgram",
... "x-custom: header"])
Parameters
----------
class_: class
class to instantiate when creating the connection. It has to implement
settimeout and connect. It's __init__ should be compatible with
WebSocket.__init__, i.e. accept all of it's kwargs.
header: list or dict
custom http header list or dict.
cookie: str
Cookie value.
origin: str
custom origin url.
suppress_origin: bool
suppress outputting origin header.
host: str
custom host header string.
timeout: int or float
socket timeout time. This value could be either float/integer.
If set to None, it uses the default_timeout value.
http_proxy_host: str
HTTP proxy host name.
http_proxy_port: str or int
HTTP proxy port. If not set, set to 80.
http_no_proxy: list
Whitelisted host names that don't use the proxy.
http_proxy_auth: tuple
HTTP proxy auth information. tuple of username and password. Default is None.
http_proxy_timeout: int or float
HTTP proxy timeout, default is 60 sec as per python-socks.
enable_multithread: bool
Enable lock for multithread.
redirect_limit: int
Number of redirects to follow.
sockopt: tuple
Values for socket.setsockopt.
sockopt must be a tuple and each element is an argument of sock.setsockopt.
sslopt: dict
Optional dict object for ssl socket options. See FAQ for details.
subprotocols: list
List of available subprotocols. Default is None.
skip_utf8_validation: bool
Skip utf8 validation.
socket: socket
Pre-initialized stream socket.
"""
sockopt = options.pop("sockopt", [])
sslopt = options.pop("sslopt", {})
fire_cont_frame = options.pop("fire_cont_frame", False)
enable_multithread = options.pop("enable_multithread", True)
skip_utf8_validation = options.pop("skip_utf8_validation", False)
websock = class_(
sockopt=sockopt,
sslopt=sslopt,
fire_cont_frame=fire_cont_frame,
enable_multithread=enable_multithread,
skip_utf8_validation=skip_utf8_validation,
**options,
)
websock.settimeout(timeout if timeout is not None else getdefaulttimeout())
websock.connect(url, **options)
return websock

View File

@@ -0,0 +1,94 @@
"""
_exceptions.py
websocket - WebSocket client library for Python
Copyright 2024 engn33r
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
class WebSocketException(Exception):
"""
WebSocket exception class.
"""
pass
class WebSocketProtocolException(WebSocketException):
"""
If the WebSocket protocol is invalid, this exception will be raised.
"""
pass
class WebSocketPayloadException(WebSocketException):
"""
If the WebSocket payload is invalid, this exception will be raised.
"""
pass
class WebSocketConnectionClosedException(WebSocketException):
"""
If remote host closed the connection or some network error happened,
this exception will be raised.
"""
pass
class WebSocketTimeoutException(WebSocketException):
"""
WebSocketTimeoutException will be raised at socket timeout during read/write data.
"""
pass
class WebSocketProxyException(WebSocketException):
"""
WebSocketProxyException will be raised when proxy error occurred.
"""
pass
class WebSocketBadStatusException(WebSocketException):
"""
WebSocketBadStatusException will be raised when we get bad handshake status code.
"""
def __init__(
self,
message: str,
status_code: int,
status_message=None,
resp_headers=None,
resp_body=None,
):
super().__init__(message)
self.status_code = status_code
self.resp_headers = resp_headers
self.resp_body = resp_body
class WebSocketAddressException(WebSocketException):
"""
If the websocket address info cannot be found, this exception will be raised.
"""
pass

View File

@@ -0,0 +1,203 @@
"""
_handshake.py
websocket - WebSocket client library for Python
Copyright 2024 engn33r
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import hashlib
import hmac
import os
from base64 import encodebytes as base64encode
from http import HTTPStatus
from ._cookiejar import SimpleCookieJar
from ._exceptions import WebSocketException, WebSocketBadStatusException
from ._http import read_headers
from ._logging import dump, error
from ._socket import send
__all__ = ["handshake_response", "handshake", "SUPPORTED_REDIRECT_STATUSES"]
# websocket supported version.
VERSION = 13
SUPPORTED_REDIRECT_STATUSES = (
HTTPStatus.MOVED_PERMANENTLY,
HTTPStatus.FOUND,
HTTPStatus.SEE_OTHER,
HTTPStatus.TEMPORARY_REDIRECT,
HTTPStatus.PERMANENT_REDIRECT,
)
SUCCESS_STATUSES = SUPPORTED_REDIRECT_STATUSES + (HTTPStatus.SWITCHING_PROTOCOLS,)
CookieJar = SimpleCookieJar()
class handshake_response:
def __init__(self, status: int, headers: dict, subprotocol):
self.status = status
self.headers = headers
self.subprotocol = subprotocol
CookieJar.add(headers.get("set-cookie"))
def handshake(
sock, url: str, hostname: str, port: int, resource: str, **options
) -> handshake_response:
headers, key = _get_handshake_headers(resource, url, hostname, port, options)
header_str = "\r\n".join(headers)
send(sock, header_str)
dump("request header", header_str)
status, resp = _get_resp_headers(sock)
if status in SUPPORTED_REDIRECT_STATUSES:
return handshake_response(status, resp, None)
success, subproto = _validate(resp, key, options.get("subprotocols"))
if not success:
raise WebSocketException("Invalid WebSocket Header")
return handshake_response(status, resp, subproto)
def _pack_hostname(hostname: str) -> str:
# IPv6 address
if ":" in hostname:
return f"[{hostname}]"
return hostname
def _get_handshake_headers(
resource: str, url: str, host: str, port: int, options: dict
) -> tuple:
headers = [f"GET {resource} HTTP/1.1", "Upgrade: websocket"]
if port in [80, 443]:
hostport = _pack_hostname(host)
else:
hostport = f"{_pack_hostname(host)}:{port}"
if options.get("host"):
headers.append(f'Host: {options["host"]}')
else:
headers.append(f"Host: {hostport}")
# scheme indicates whether http or https is used in Origin
# The same approach is used in parse_url of _url.py to set default port
scheme, url = url.split(":", 1)
if not options.get("suppress_origin"):
if "origin" in options and options["origin"] is not None:
headers.append(f'Origin: {options["origin"]}')
elif scheme == "wss":
headers.append(f"Origin: https://{hostport}")
else:
headers.append(f"Origin: http://{hostport}")
key = _create_sec_websocket_key()
# Append Sec-WebSocket-Key & Sec-WebSocket-Version if not manually specified
if not options.get("header") or "Sec-WebSocket-Key" not in options["header"]:
headers.append(f"Sec-WebSocket-Key: {key}")
else:
key = options["header"]["Sec-WebSocket-Key"]
if not options.get("header") or "Sec-WebSocket-Version" not in options["header"]:
headers.append(f"Sec-WebSocket-Version: {VERSION}")
if not options.get("connection"):
headers.append("Connection: Upgrade")
else:
headers.append(options["connection"])
if subprotocols := options.get("subprotocols"):
headers.append(f'Sec-WebSocket-Protocol: {",".join(subprotocols)}')
if header := options.get("header"):
if isinstance(header, dict):
header = [": ".join([k, v]) for k, v in header.items() if v is not None]
headers.extend(header)
server_cookie = CookieJar.get(host)
client_cookie = options.get("cookie", None)
if cookie := "; ".join(filter(None, [server_cookie, client_cookie])):
headers.append(f"Cookie: {cookie}")
headers.extend(("", ""))
return headers, key
def _get_resp_headers(sock, success_statuses: tuple = SUCCESS_STATUSES) -> tuple:
status, resp_headers, status_message = read_headers(sock)
if status not in success_statuses:
content_len = resp_headers.get("content-length")
if content_len:
response_body = sock.recv(
int(content_len)
) # read the body of the HTTP error message response and include it in the exception
else:
response_body = None
raise WebSocketBadStatusException(
f"Handshake status {status} {status_message} -+-+- {resp_headers} -+-+- {response_body}",
status,
status_message,
resp_headers,
response_body,
)
return status, resp_headers
_HEADERS_TO_CHECK = {
"upgrade": "websocket",
"connection": "upgrade",
}
def _validate(headers, key: str, subprotocols) -> tuple:
subproto = None
for k, v in _HEADERS_TO_CHECK.items():
r = headers.get(k, None)
if not r:
return False, None
r = [x.strip().lower() for x in r.split(",")]
if v not in r:
return False, None
if subprotocols:
subproto = headers.get("sec-websocket-protocol", None)
if not subproto or subproto.lower() not in [s.lower() for s in subprotocols]:
error(f"Invalid subprotocol: {subprotocols}")
return False, None
subproto = subproto.lower()
result = headers.get("sec-websocket-accept", None)
if not result:
return False, None
result = result.lower()
if isinstance(result, str):
result = result.encode("utf-8")
value = f"{key}258EAFA5-E914-47DA-95CA-C5AB0DC85B11".encode("utf-8")
hashed = base64encode(hashlib.sha1(value).digest()).strip().lower()
if hmac.compare_digest(hashed, result):
return True, subproto
else:
return False, None
def _create_sec_websocket_key() -> str:
randomness = os.urandom(16)
return base64encode(randomness).decode("utf-8").strip()

View File

@@ -0,0 +1,374 @@
"""
_http.py
websocket - WebSocket client library for Python
Copyright 2024 engn33r
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import errno
import os
import socket
from base64 import encodebytes as base64encode
from ._exceptions import (
WebSocketAddressException,
WebSocketException,
WebSocketProxyException,
)
from ._logging import debug, dump, trace
from ._socket import DEFAULT_SOCKET_OPTION, recv_line, send
from ._ssl_compat import HAVE_SSL, ssl
from ._url import get_proxy_info, parse_url
__all__ = ["proxy_info", "connect", "read_headers"]
try:
from python_socks._errors import *
from python_socks._types import ProxyType
from python_socks.sync import Proxy
HAVE_PYTHON_SOCKS = True
except:
HAVE_PYTHON_SOCKS = False
class ProxyError(Exception):
pass
class ProxyTimeoutError(Exception):
pass
class ProxyConnectionError(Exception):
pass
class proxy_info:
def __init__(self, **options):
self.proxy_host = options.get("http_proxy_host", None)
if self.proxy_host:
self.proxy_port = options.get("http_proxy_port", 0)
self.auth = options.get("http_proxy_auth", None)
self.no_proxy = options.get("http_no_proxy", None)
self.proxy_protocol = options.get("proxy_type", "http")
# Note: If timeout not specified, default python-socks timeout is 60 seconds
self.proxy_timeout = options.get("http_proxy_timeout", None)
if self.proxy_protocol not in [
"http",
"socks4",
"socks4a",
"socks5",
"socks5h",
]:
raise ProxyError(
"Only http, socks4, socks5 proxy protocols are supported"
)
else:
self.proxy_port = 0
self.auth = None
self.no_proxy = None
self.proxy_protocol = "http"
def _start_proxied_socket(url: str, options, proxy) -> tuple:
if not HAVE_PYTHON_SOCKS:
raise WebSocketException(
"Python Socks is needed for SOCKS proxying but is not available"
)
hostname, port, resource, is_secure = parse_url(url)
if proxy.proxy_protocol == "socks4":
rdns = False
proxy_type = ProxyType.SOCKS4
# socks4a sends DNS through proxy
elif proxy.proxy_protocol == "socks4a":
rdns = True
proxy_type = ProxyType.SOCKS4
elif proxy.proxy_protocol == "socks5":
rdns = False
proxy_type = ProxyType.SOCKS5
# socks5h sends DNS through proxy
elif proxy.proxy_protocol == "socks5h":
rdns = True
proxy_type = ProxyType.SOCKS5
ws_proxy = Proxy.create(
proxy_type=proxy_type,
host=proxy.proxy_host,
port=int(proxy.proxy_port),
username=proxy.auth[0] if proxy.auth else None,
password=proxy.auth[1] if proxy.auth else None,
rdns=rdns,
)
sock = ws_proxy.connect(hostname, port, timeout=proxy.proxy_timeout)
if is_secure:
if HAVE_SSL:
sock = _ssl_socket(sock, options.sslopt, hostname)
else:
raise WebSocketException("SSL not available.")
return sock, (hostname, port, resource)
def connect(url: str, options, proxy, socket):
# Use _start_proxied_socket() only for socks4 or socks5 proxy
# Use _tunnel() for http proxy
# TODO: Use python-socks for http protocol also, to standardize flow
if proxy.proxy_host and not socket and proxy.proxy_protocol != "http":
return _start_proxied_socket(url, options, proxy)
hostname, port_from_url, resource, is_secure = parse_url(url)
if socket:
return socket, (hostname, port_from_url, resource)
addrinfo_list, need_tunnel, auth = _get_addrinfo_list(
hostname, port_from_url, is_secure, proxy
)
if not addrinfo_list:
raise WebSocketException(f"Host not found.: {hostname}:{port_from_url}")
sock = None
try:
sock = _open_socket(addrinfo_list, options.sockopt, options.timeout)
if need_tunnel:
sock = _tunnel(sock, hostname, port_from_url, auth)
if is_secure:
if HAVE_SSL:
sock = _ssl_socket(sock, options.sslopt, hostname)
else:
raise WebSocketException("SSL not available.")
return sock, (hostname, port_from_url, resource)
except:
if sock:
sock.close()
raise
def _get_addrinfo_list(hostname, port: int, is_secure: bool, proxy) -> tuple:
phost, pport, pauth = get_proxy_info(
hostname,
is_secure,
proxy.proxy_host,
proxy.proxy_port,
proxy.auth,
proxy.no_proxy,
)
try:
# when running on windows 10, getaddrinfo without socktype returns a socktype 0.
# This generates an error exception: `_on_error: exception Socket type must be stream or datagram, not 0`
# or `OSError: [Errno 22] Invalid argument` when creating socket. Force the socket type to SOCK_STREAM.
if not phost:
addrinfo_list = socket.getaddrinfo(
hostname, port, 0, socket.SOCK_STREAM, socket.SOL_TCP
)
return addrinfo_list, False, None
else:
pport = pport and pport or 80
# when running on windows 10, the getaddrinfo used above
# returns a socktype 0. This generates an error exception:
# _on_error: exception Socket type must be stream or datagram, not 0
# Force the socket type to SOCK_STREAM
addrinfo_list = socket.getaddrinfo(
phost, pport, 0, socket.SOCK_STREAM, socket.SOL_TCP
)
return addrinfo_list, True, pauth
except socket.gaierror as e:
raise WebSocketAddressException(e)
def _open_socket(addrinfo_list, sockopt, timeout):
err = None
for addrinfo in addrinfo_list:
family, socktype, proto = addrinfo[:3]
sock = socket.socket(family, socktype, proto)
sock.settimeout(timeout)
for opts in DEFAULT_SOCKET_OPTION:
sock.setsockopt(*opts)
for opts in sockopt:
sock.setsockopt(*opts)
address = addrinfo[4]
err = None
while not err:
try:
sock.connect(address)
except socket.error as error:
sock.close()
error.remote_ip = str(address[0])
try:
eConnRefused = (
errno.ECONNREFUSED,
errno.WSAECONNREFUSED,
errno.ENETUNREACH,
)
except AttributeError:
eConnRefused = (errno.ECONNREFUSED, errno.ENETUNREACH)
if error.errno not in eConnRefused:
raise error
err = error
continue
else:
break
else:
continue
break
else:
if err:
raise err
return sock
def _wrap_sni_socket(sock: socket.socket, sslopt: dict, hostname, check_hostname):
context = sslopt.get("context", None)
if not context:
context = ssl.SSLContext(sslopt.get("ssl_version", ssl.PROTOCOL_TLS_CLIENT))
# Non default context need to manually enable SSLKEYLOGFILE support by setting the keylog_filename attribute.
# For more details see also:
# * https://docs.python.org/3.8/library/ssl.html?highlight=sslkeylogfile#context-creation
# * https://docs.python.org/3.8/library/ssl.html?highlight=sslkeylogfile#ssl.SSLContext.keylog_filename
context.keylog_filename = os.environ.get("SSLKEYLOGFILE", None)
if sslopt.get("cert_reqs", ssl.CERT_NONE) != ssl.CERT_NONE:
cafile = sslopt.get("ca_certs", None)
capath = sslopt.get("ca_cert_path", None)
if cafile or capath:
context.load_verify_locations(cafile=cafile, capath=capath)
elif hasattr(context, "load_default_certs"):
context.load_default_certs(ssl.Purpose.SERVER_AUTH)
if sslopt.get("certfile", None):
context.load_cert_chain(
sslopt["certfile"],
sslopt.get("keyfile", None),
sslopt.get("password", None),
)
# Python 3.10 switch to PROTOCOL_TLS_CLIENT defaults to "cert_reqs = ssl.CERT_REQUIRED" and "check_hostname = True"
# If both disabled, set check_hostname before verify_mode
# see https://github.com/liris/websocket-client/commit/b96a2e8fa765753e82eea531adb19716b52ca3ca#commitcomment-10803153
if sslopt.get("cert_reqs", ssl.CERT_NONE) == ssl.CERT_NONE and not sslopt.get(
"check_hostname", False
):
context.check_hostname = False
context.verify_mode = ssl.CERT_NONE
else:
context.check_hostname = sslopt.get("check_hostname", True)
context.verify_mode = sslopt.get("cert_reqs", ssl.CERT_REQUIRED)
if "ciphers" in sslopt:
context.set_ciphers(sslopt["ciphers"])
if "cert_chain" in sslopt:
certfile, keyfile, password = sslopt["cert_chain"]
context.load_cert_chain(certfile, keyfile, password)
if "ecdh_curve" in sslopt:
context.set_ecdh_curve(sslopt["ecdh_curve"])
return context.wrap_socket(
sock,
do_handshake_on_connect=sslopt.get("do_handshake_on_connect", True),
suppress_ragged_eofs=sslopt.get("suppress_ragged_eofs", True),
server_hostname=hostname,
)
def _ssl_socket(sock: socket.socket, user_sslopt: dict, hostname):
sslopt: dict = {"cert_reqs": ssl.CERT_REQUIRED}
sslopt.update(user_sslopt)
cert_path = os.environ.get("WEBSOCKET_CLIENT_CA_BUNDLE")
if (
cert_path
and os.path.isfile(cert_path)
and user_sslopt.get("ca_certs", None) is None
):
sslopt["ca_certs"] = cert_path
elif (
cert_path
and os.path.isdir(cert_path)
and user_sslopt.get("ca_cert_path", None) is None
):
sslopt["ca_cert_path"] = cert_path
if sslopt.get("server_hostname", None):
hostname = sslopt["server_hostname"]
check_hostname = sslopt.get("check_hostname", True)
sock = _wrap_sni_socket(sock, sslopt, hostname, check_hostname)
return sock
def _tunnel(sock: socket.socket, host, port: int, auth) -> socket.socket:
debug("Connecting proxy...")
connect_header = f"CONNECT {host}:{port} HTTP/1.1\r\n"
connect_header += f"Host: {host}:{port}\r\n"
# TODO: support digest auth.
if auth and auth[0]:
auth_str = auth[0]
if auth[1]:
auth_str += f":{auth[1]}"
encoded_str = base64encode(auth_str.encode()).strip().decode().replace("\n", "")
connect_header += f"Proxy-Authorization: Basic {encoded_str}\r\n"
connect_header += "\r\n"
dump("request header", connect_header)
send(sock, connect_header)
try:
status, _, _ = read_headers(sock)
except Exception as e:
raise WebSocketProxyException(str(e))
if status != 200:
raise WebSocketProxyException(f"failed CONNECT via proxy status: {status}")
return sock
def read_headers(sock: socket.socket) -> tuple:
status = None
status_message = None
headers: dict = {}
trace("--- response header ---")
while True:
line = recv_line(sock)
line = line.decode("utf-8").strip()
if not line:
break
trace(line)
if not status:
status_info = line.split(" ", 2)
status = int(status_info[1])
if len(status_info) > 2:
status_message = status_info[2]
else:
kv = line.split(":", 1)
if len(kv) != 2:
raise WebSocketException("Invalid header")
key, value = kv
if key.lower() == "set-cookie" and headers.get("set-cookie"):
headers["set-cookie"] = headers.get("set-cookie") + "; " + value.strip()
else:
headers[key.lower()] = value.strip()
trace("-----------------------")
return status, headers, status_message

View File

@@ -0,0 +1,106 @@
import logging
"""
_logging.py
websocket - WebSocket client library for Python
Copyright 2024 engn33r
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
_logger = logging.getLogger("websocket")
try:
from logging import NullHandler
except ImportError:
class NullHandler(logging.Handler):
def emit(self, record) -> None:
pass
_logger.addHandler(NullHandler())
_traceEnabled = False
__all__ = [
"enableTrace",
"dump",
"error",
"warning",
"debug",
"trace",
"isEnabledForError",
"isEnabledForDebug",
"isEnabledForTrace",
]
def enableTrace(
traceable: bool,
handler: logging.StreamHandler = logging.StreamHandler(),
level: str = "DEBUG",
) -> None:
"""
Turn on/off the traceability.
Parameters
----------
traceable: bool
If set to True, traceability is enabled.
"""
global _traceEnabled
_traceEnabled = traceable
if traceable:
_logger.addHandler(handler)
_logger.setLevel(getattr(logging, level))
def dump(title: str, message: str) -> None:
if _traceEnabled:
_logger.debug(f"--- {title} ---")
_logger.debug(message)
_logger.debug("-----------------------")
def error(msg: str) -> None:
_logger.error(msg)
def warning(msg: str) -> None:
_logger.warning(msg)
def debug(msg: str) -> None:
_logger.debug(msg)
def info(msg: str) -> None:
_logger.info(msg)
def trace(msg: str) -> None:
if _traceEnabled:
_logger.debug(msg)
def isEnabledForError() -> bool:
return _logger.isEnabledFor(logging.ERROR)
def isEnabledForDebug() -> bool:
return _logger.isEnabledFor(logging.DEBUG)
def isEnabledForTrace() -> bool:
return _traceEnabled

View File

@@ -0,0 +1,188 @@
import errno
import selectors
import socket
from typing import Union
from ._exceptions import (
WebSocketConnectionClosedException,
WebSocketTimeoutException,
)
from ._ssl_compat import SSLError, SSLWantReadError, SSLWantWriteError
from ._utils import extract_error_code, extract_err_message
"""
_socket.py
websocket - WebSocket client library for Python
Copyright 2024 engn33r
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
DEFAULT_SOCKET_OPTION = [(socket.SOL_TCP, socket.TCP_NODELAY, 1)]
if hasattr(socket, "SO_KEEPALIVE"):
DEFAULT_SOCKET_OPTION.append((socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1))
if hasattr(socket, "TCP_KEEPIDLE"):
DEFAULT_SOCKET_OPTION.append((socket.SOL_TCP, socket.TCP_KEEPIDLE, 30))
if hasattr(socket, "TCP_KEEPINTVL"):
DEFAULT_SOCKET_OPTION.append((socket.SOL_TCP, socket.TCP_KEEPINTVL, 10))
if hasattr(socket, "TCP_KEEPCNT"):
DEFAULT_SOCKET_OPTION.append((socket.SOL_TCP, socket.TCP_KEEPCNT, 3))
_default_timeout = None
__all__ = [
"DEFAULT_SOCKET_OPTION",
"sock_opt",
"setdefaulttimeout",
"getdefaulttimeout",
"recv",
"recv_line",
"send",
]
class sock_opt:
def __init__(self, sockopt: list, sslopt: dict) -> None:
if sockopt is None:
sockopt = []
if sslopt is None:
sslopt = {}
self.sockopt = sockopt
self.sslopt = sslopt
self.timeout = None
def setdefaulttimeout(timeout: Union[int, float, None]) -> None:
"""
Set the global timeout setting to connect.
Parameters
----------
timeout: int or float
default socket timeout time (in seconds)
"""
global _default_timeout
_default_timeout = timeout
def getdefaulttimeout() -> Union[int, float, None]:
"""
Get default timeout
Returns
----------
_default_timeout: int or float
Return the global timeout setting (in seconds) to connect.
"""
return _default_timeout
def recv(sock: socket.socket, bufsize: int) -> bytes:
if not sock:
raise WebSocketConnectionClosedException("socket is already closed.")
def _recv():
try:
return sock.recv(bufsize)
except SSLWantReadError:
pass
except socket.error as exc:
error_code = extract_error_code(exc)
if error_code not in [errno.EAGAIN, errno.EWOULDBLOCK]:
raise
sel = selectors.DefaultSelector()
sel.register(sock, selectors.EVENT_READ)
r = sel.select(sock.gettimeout())
sel.close()
if r:
return sock.recv(bufsize)
try:
if sock.gettimeout() == 0:
bytes_ = sock.recv(bufsize)
else:
bytes_ = _recv()
except TimeoutError:
raise WebSocketTimeoutException("Connection timed out")
except socket.timeout as e:
message = extract_err_message(e)
raise WebSocketTimeoutException(message)
except SSLError as e:
message = extract_err_message(e)
if isinstance(message, str) and "timed out" in message:
raise WebSocketTimeoutException(message)
else:
raise
if not bytes_:
raise WebSocketConnectionClosedException("Connection to remote host was lost.")
return bytes_
def recv_line(sock: socket.socket) -> bytes:
line = []
while True:
c = recv(sock, 1)
line.append(c)
if c == b"\n":
break
return b"".join(line)
def send(sock: socket.socket, data: Union[bytes, str]) -> int:
if isinstance(data, str):
data = data.encode("utf-8")
if not sock:
raise WebSocketConnectionClosedException("socket is already closed.")
def _send():
try:
return sock.send(data)
except SSLWantWriteError:
pass
except socket.error as exc:
error_code = extract_error_code(exc)
if error_code is None:
raise
if error_code not in [errno.EAGAIN, errno.EWOULDBLOCK]:
raise
sel = selectors.DefaultSelector()
sel.register(sock, selectors.EVENT_WRITE)
w = sel.select(sock.gettimeout())
sel.close()
if w:
return sock.send(data)
try:
if sock.gettimeout() == 0:
return sock.send(data)
else:
return _send()
except socket.timeout as e:
message = extract_err_message(e)
raise WebSocketTimeoutException(message)
except Exception as e:
message = extract_err_message(e)
if isinstance(message, str) and "timed out" in message:
raise WebSocketTimeoutException(message)
else:
raise

View File

@@ -0,0 +1,49 @@
"""
_ssl_compat.py
websocket - WebSocket client library for Python
Copyright 2024 engn33r
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
__all__ = [
"HAVE_SSL",
"ssl",
"SSLError",
"SSLEOFError",
"SSLWantReadError",
"SSLWantWriteError",
]
try:
import ssl
from ssl import SSLError, SSLEOFError, SSLWantReadError, SSLWantWriteError
HAVE_SSL = True
except ImportError:
# dummy class of SSLError for environment without ssl support
class SSLError(Exception):
pass
class SSLEOFError(Exception):
pass
class SSLWantReadError(Exception):
pass
class SSLWantWriteError(Exception):
pass
ssl = None
HAVE_SSL = False

View File

@@ -0,0 +1,190 @@
import os
import socket
import struct
from typing import Optional
from urllib.parse import unquote, urlparse
from ._exceptions import WebSocketProxyException
"""
_url.py
websocket - WebSocket client library for Python
Copyright 2024 engn33r
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
__all__ = ["parse_url", "get_proxy_info"]
def parse_url(url: str) -> tuple:
"""
parse url and the result is tuple of
(hostname, port, resource path and the flag of secure mode)
Parameters
----------
url: str
url string.
"""
if ":" not in url:
raise ValueError("url is invalid")
scheme, url = url.split(":", 1)
parsed = urlparse(url, scheme="http")
if parsed.hostname:
hostname = parsed.hostname
else:
raise ValueError("hostname is invalid")
port = 0
if parsed.port:
port = parsed.port
is_secure = False
if scheme == "ws":
if not port:
port = 80
elif scheme == "wss":
is_secure = True
if not port:
port = 443
else:
raise ValueError("scheme %s is invalid" % scheme)
if parsed.path:
resource = parsed.path
else:
resource = "/"
if parsed.query:
resource += f"?{parsed.query}"
return hostname, port, resource, is_secure
DEFAULT_NO_PROXY_HOST = ["localhost", "127.0.0.1"]
def _is_ip_address(addr: str) -> bool:
try:
socket.inet_aton(addr)
except socket.error:
return False
else:
return True
def _is_subnet_address(hostname: str) -> bool:
try:
addr, netmask = hostname.split("/")
return _is_ip_address(addr) and 0 <= int(netmask) < 32
except ValueError:
return False
def _is_address_in_network(ip: str, net: str) -> bool:
ipaddr: int = struct.unpack("!I", socket.inet_aton(ip))[0]
netaddr, netmask = net.split("/")
netaddr: int = struct.unpack("!I", socket.inet_aton(netaddr))[0]
netmask = (0xFFFFFFFF << (32 - int(netmask))) & 0xFFFFFFFF
return ipaddr & netmask == netaddr
def _is_no_proxy_host(hostname: str, no_proxy: Optional[list]) -> bool:
if not no_proxy:
if v := os.environ.get("no_proxy", os.environ.get("NO_PROXY", "")).replace(
" ", ""
):
no_proxy = v.split(",")
if not no_proxy:
no_proxy = DEFAULT_NO_PROXY_HOST
if "*" in no_proxy:
return True
if hostname in no_proxy:
return True
if _is_ip_address(hostname):
return any(
[
_is_address_in_network(hostname, subnet)
for subnet in no_proxy
if _is_subnet_address(subnet)
]
)
for domain in [domain for domain in no_proxy if domain.startswith(".")]:
if hostname.endswith(domain):
return True
return False
def get_proxy_info(
hostname: str,
is_secure: bool,
proxy_host: Optional[str] = None,
proxy_port: int = 0,
proxy_auth: Optional[tuple] = None,
no_proxy: Optional[list] = None,
proxy_type: str = "http",
) -> tuple:
"""
Try to retrieve proxy host and port from environment
if not provided in options.
Result is (proxy_host, proxy_port, proxy_auth).
proxy_auth is tuple of username and password
of proxy authentication information.
Parameters
----------
hostname: str
Websocket server name.
is_secure: bool
Is the connection secure? (wss) looks for "https_proxy" in env
instead of "http_proxy"
proxy_host: str
http proxy host name.
proxy_port: str or int
http proxy port.
no_proxy: list
Whitelisted host names that don't use the proxy.
proxy_auth: tuple
HTTP proxy auth information. Tuple of username and password. Default is None.
proxy_type: str
Specify the proxy protocol (http, socks4, socks4a, socks5, socks5h). Default is "http".
Use socks4a or socks5h if you want to send DNS requests through the proxy.
"""
if _is_no_proxy_host(hostname, no_proxy):
return None, 0, None
if proxy_host:
if not proxy_port:
raise WebSocketProxyException("Cannot use port 0 when proxy_host specified")
port = proxy_port
auth = proxy_auth
return proxy_host, port, auth
env_key = "https_proxy" if is_secure else "http_proxy"
value = os.environ.get(env_key, os.environ.get(env_key.upper(), "")).replace(
" ", ""
)
if value:
proxy = urlparse(value)
auth = (
(unquote(proxy.username), unquote(proxy.password))
if proxy.username
else None
)
return proxy.hostname, proxy.port, auth
return None, 0, None

View File

@@ -0,0 +1,459 @@
from typing import Union
"""
_url.py
websocket - WebSocket client library for Python
Copyright 2024 engn33r
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
__all__ = ["NoLock", "validate_utf8", "extract_err_message", "extract_error_code"]
class NoLock:
def __enter__(self) -> None:
pass
def __exit__(self, exc_type, exc_value, traceback) -> None:
pass
try:
# If wsaccel is available we use compiled routines to validate UTF-8
# strings.
from wsaccel.utf8validator import Utf8Validator
def _validate_utf8(utfbytes: Union[str, bytes]) -> bool:
result: bool = Utf8Validator().validate(utfbytes)[0]
return result
except ImportError:
# UTF-8 validator
# python implementation of http://bjoern.hoehrmann.de/utf-8/decoder/dfa/
_UTF8_ACCEPT = 0
_UTF8_REJECT = 12
_UTF8D = [
# The first part of the table maps bytes to character classes that
# to reduce the size of the transition table and create bitmasks.
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
9,
9,
9,
9,
9,
9,
9,
9,
9,
9,
9,
9,
9,
9,
9,
9,
7,
7,
7,
7,
7,
7,
7,
7,
7,
7,
7,
7,
7,
7,
7,
7,
7,
7,
7,
7,
7,
7,
7,
7,
7,
7,
7,
7,
7,
7,
7,
7,
8,
8,
2,
2,
2,
2,
2,
2,
2,
2,
2,
2,
2,
2,
2,
2,
2,
2,
2,
2,
2,
2,
2,
2,
2,
2,
2,
2,
2,
2,
2,
2,
10,
3,
3,
3,
3,
3,
3,
3,
3,
3,
3,
3,
3,
4,
3,
3,
11,
6,
6,
6,
5,
8,
8,
8,
8,
8,
8,
8,
8,
8,
8,
8,
# The second part is a transition table that maps a combination
# of a state of the automaton and a character class to a state.
0,
12,
24,
36,
60,
96,
84,
12,
12,
12,
48,
72,
12,
12,
12,
12,
12,
12,
12,
12,
12,
12,
12,
12,
12,
0,
12,
12,
12,
12,
12,
0,
12,
0,
12,
12,
12,
24,
12,
12,
12,
12,
12,
24,
12,
24,
12,
12,
12,
12,
12,
12,
12,
12,
12,
24,
12,
12,
12,
12,
12,
24,
12,
12,
12,
12,
12,
12,
12,
24,
12,
12,
12,
12,
12,
12,
12,
12,
12,
36,
12,
36,
12,
12,
12,
36,
12,
12,
12,
12,
12,
36,
12,
36,
12,
12,
12,
36,
12,
12,
12,
12,
12,
12,
12,
12,
12,
12,
]
def _decode(state: int, codep: int, ch: int) -> tuple:
tp = _UTF8D[ch]
codep = (
(ch & 0x3F) | (codep << 6) if (state != _UTF8_ACCEPT) else (0xFF >> tp) & ch
)
state = _UTF8D[256 + state + tp]
return state, codep
def _validate_utf8(utfbytes: Union[str, bytes]) -> bool:
state = _UTF8_ACCEPT
codep = 0
for i in utfbytes:
state, codep = _decode(state, codep, int(i))
if state == _UTF8_REJECT:
return False
return True
def validate_utf8(utfbytes: Union[str, bytes]) -> bool:
"""
validate utf8 byte string.
utfbytes: utf byte string to check.
return value: if valid utf8 string, return true. Otherwise, return false.
"""
return _validate_utf8(utfbytes)
def extract_err_message(exception: Exception) -> Union[str, None]:
if exception.args:
exception_message: str = exception.args[0]
return exception_message
else:
return None
def extract_error_code(exception: Exception) -> Union[int, None]:
if exception.args and len(exception.args) > 1:
return exception.args[0] if isinstance(exception.args[0], int) else None

View File

@@ -0,0 +1,244 @@
#!/usr/bin/env python3
"""
wsdump.py
websocket - WebSocket client library for Python
Copyright 2024 engn33r
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import argparse
import code
import gzip
import ssl
import sys
import threading
import time
import zlib
from urllib.parse import urlparse
import websocket
try:
import readline
except ImportError:
pass
def get_encoding() -> str:
encoding = getattr(sys.stdin, "encoding", "")
if not encoding:
return "utf-8"
else:
return encoding.lower()
OPCODE_DATA = (websocket.ABNF.OPCODE_TEXT, websocket.ABNF.OPCODE_BINARY)
ENCODING = get_encoding()
class VAction(argparse.Action):
def __call__(
self,
parser: argparse.Namespace,
args: tuple,
values: str,
option_string: str = None,
) -> None:
if values is None:
values = "1"
try:
values = int(values)
except ValueError:
values = values.count("v") + 1
setattr(args, self.dest, values)
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="WebSocket Simple Dump Tool")
parser.add_argument(
"url", metavar="ws_url", help="websocket url. ex. ws://echo.websocket.events/"
)
parser.add_argument("-p", "--proxy", help="proxy url. ex. http://127.0.0.1:8080")
parser.add_argument(
"-v",
"--verbose",
default=0,
nargs="?",
action=VAction,
dest="verbose",
help="set verbose mode. If set to 1, show opcode. "
"If set to 2, enable to trace websocket module",
)
parser.add_argument(
"-n", "--nocert", action="store_true", help="Ignore invalid SSL cert"
)
parser.add_argument("-r", "--raw", action="store_true", help="raw output")
parser.add_argument("-s", "--subprotocols", nargs="*", help="Set subprotocols")
parser.add_argument("-o", "--origin", help="Set origin")
parser.add_argument(
"--eof-wait",
default=0,
type=int,
help="wait time(second) after 'EOF' received.",
)
parser.add_argument("-t", "--text", help="Send initial text")
parser.add_argument(
"--timings", action="store_true", help="Print timings in seconds"
)
parser.add_argument("--headers", help="Set custom headers. Use ',' as separator")
return parser.parse_args()
class RawInput:
def raw_input(self, prompt: str = "") -> str:
line = input(prompt)
if ENCODING and ENCODING != "utf-8" and not isinstance(line, str):
line = line.decode(ENCODING).encode("utf-8")
elif isinstance(line, str):
line = line.encode("utf-8")
return line
class InteractiveConsole(RawInput, code.InteractiveConsole):
def write(self, data: str) -> None:
sys.stdout.write("\033[2K\033[E")
# sys.stdout.write("\n")
sys.stdout.write("\033[34m< " + data + "\033[39m")
sys.stdout.write("\n> ")
sys.stdout.flush()
def read(self) -> str:
return self.raw_input("> ")
class NonInteractive(RawInput):
def write(self, data: str) -> None:
sys.stdout.write(data)
sys.stdout.write("\n")
sys.stdout.flush()
def read(self) -> str:
return self.raw_input("")
def main() -> None:
start_time = time.time()
args = parse_args()
if args.verbose > 1:
websocket.enableTrace(True)
options = {}
if args.proxy:
p = urlparse(args.proxy)
options["http_proxy_host"] = p.hostname
options["http_proxy_port"] = p.port
if args.origin:
options["origin"] = args.origin
if args.subprotocols:
options["subprotocols"] = args.subprotocols
opts = {}
if args.nocert:
opts = {"cert_reqs": ssl.CERT_NONE, "check_hostname": False}
if args.headers:
options["header"] = list(map(str.strip, args.headers.split(",")))
ws = websocket.create_connection(args.url, sslopt=opts, **options)
if args.raw:
console = NonInteractive()
else:
console = InteractiveConsole()
print("Press Ctrl+C to quit")
def recv() -> tuple:
try:
frame = ws.recv_frame()
except websocket.WebSocketException:
return websocket.ABNF.OPCODE_CLOSE, ""
if not frame:
raise websocket.WebSocketException(f"Not a valid frame {frame}")
elif frame.opcode in OPCODE_DATA:
return frame.opcode, frame.data
elif frame.opcode == websocket.ABNF.OPCODE_CLOSE:
ws.send_close()
return frame.opcode, ""
elif frame.opcode == websocket.ABNF.OPCODE_PING:
ws.pong(frame.data)
return frame.opcode, frame.data
return frame.opcode, frame.data
def recv_ws() -> None:
while True:
opcode, data = recv()
msg = None
if opcode == websocket.ABNF.OPCODE_TEXT and isinstance(data, bytes):
data = str(data, "utf-8")
if (
isinstance(data, bytes) and len(data) > 2 and data[:2] == b"\037\213"
): # gzip magick
try:
data = "[gzip] " + str(gzip.decompress(data), "utf-8")
except:
pass
elif isinstance(data, bytes):
try:
data = "[zlib] " + str(
zlib.decompress(data, -zlib.MAX_WBITS), "utf-8"
)
except:
pass
if isinstance(data, bytes):
data = repr(data)
if args.verbose:
msg = f"{websocket.ABNF.OPCODE_MAP.get(opcode)}: {data}"
else:
msg = data
if msg is not None:
if args.timings:
console.write(f"{time.time() - start_time}: {msg}")
else:
console.write(msg)
if opcode == websocket.ABNF.OPCODE_CLOSE:
break
thread = threading.Thread(target=recv_ws)
thread.daemon = True
thread.start()
if args.text:
ws.send(args.text)
while True:
try:
message = console.read()
ws.send(message)
except KeyboardInterrupt:
return
except EOFError:
time.sleep(args.eof_wait)
return
if __name__ == "__main__":
try:
main()
except Exception as e:
print(e)

View File

@@ -0,0 +1,6 @@
HTTP/1.1 101 WebSocket Protocol Handshake
Connection: Upgrade
Upgrade: WebSocket
Sec-WebSocket-Accept: Kxep+hNu9n51529fGidYu7a3wO0=
some_header: something

View File

@@ -0,0 +1,6 @@
HTTP/1.1 101 WebSocket Protocol Handshake
Connection: Upgrade
Upgrade WebSocket
Sec-WebSocket-Accept: Kxep+hNu9n51529fGidYu7a3wO0=
some_header: something

View File

@@ -0,0 +1,8 @@
HTTP/1.1 101 WebSocket Protocol Handshake
Connection: Upgrade, Keep-Alive
Upgrade: WebSocket
Sec-WebSocket-Accept: Kxep+hNu9n51529fGidYu7a3wO0=
Set-Cookie: Token=ABCDE
Set-Cookie: Token=FGHIJ
some_header: something

View File

@@ -0,0 +1,23 @@
#!/usr/bin/env python
# From https://github.com/aaugustin/websockets/blob/main/example/echo.py
import asyncio
import os
import websockets
LOCAL_WS_SERVER_PORT = int(os.environ.get("LOCAL_WS_SERVER_PORT", "8765"))
async def echo(websocket):
async for message in websocket:
await websocket.send(message)
async def main():
async with websockets.serve(echo, "localhost", LOCAL_WS_SERVER_PORT):
await asyncio.Future() # run forever
asyncio.run(main())

View File

@@ -0,0 +1,125 @@
# -*- coding: utf-8 -*-
#
import unittest
from websocket._abnf import ABNF, frame_buffer
from websocket._exceptions import WebSocketProtocolException
"""
test_abnf.py
websocket - WebSocket client library for Python
Copyright 2024 engn33r
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
class ABNFTest(unittest.TestCase):
def test_init(self):
a = ABNF(0, 0, 0, 0, opcode=ABNF.OPCODE_PING)
self.assertEqual(a.fin, 0)
self.assertEqual(a.rsv1, 0)
self.assertEqual(a.rsv2, 0)
self.assertEqual(a.rsv3, 0)
self.assertEqual(a.opcode, 9)
self.assertEqual(a.data, "")
a_bad = ABNF(0, 1, 0, 0, opcode=77)
self.assertEqual(a_bad.rsv1, 1)
self.assertEqual(a_bad.opcode, 77)
def test_validate(self):
a_invalid_ping = ABNF(0, 0, 0, 0, opcode=ABNF.OPCODE_PING)
self.assertRaises(
WebSocketProtocolException,
a_invalid_ping.validate,
skip_utf8_validation=False,
)
a_bad_rsv_value = ABNF(0, 1, 0, 0, opcode=ABNF.OPCODE_TEXT)
self.assertRaises(
WebSocketProtocolException,
a_bad_rsv_value.validate,
skip_utf8_validation=False,
)
a_bad_opcode = ABNF(0, 0, 0, 0, opcode=77)
self.assertRaises(
WebSocketProtocolException,
a_bad_opcode.validate,
skip_utf8_validation=False,
)
a_bad_close_frame = ABNF(0, 0, 0, 0, opcode=ABNF.OPCODE_CLOSE, data=b"\x01")
self.assertRaises(
WebSocketProtocolException,
a_bad_close_frame.validate,
skip_utf8_validation=False,
)
a_bad_close_frame_2 = ABNF(
0, 0, 0, 0, opcode=ABNF.OPCODE_CLOSE, data=b"\x01\x8a\xaa\xff\xdd"
)
self.assertRaises(
WebSocketProtocolException,
a_bad_close_frame_2.validate,
skip_utf8_validation=False,
)
a_bad_close_frame_3 = ABNF(
0, 0, 0, 0, opcode=ABNF.OPCODE_CLOSE, data=b"\x03\xe7"
)
self.assertRaises(
WebSocketProtocolException,
a_bad_close_frame_3.validate,
skip_utf8_validation=True,
)
def test_mask(self):
abnf_none_data = ABNF(
0, 0, 0, 0, opcode=ABNF.OPCODE_PING, mask_value=1, data=None
)
bytes_val = b"aaaa"
self.assertEqual(abnf_none_data._get_masked(bytes_val), bytes_val)
abnf_str_data = ABNF(
0, 0, 0, 0, opcode=ABNF.OPCODE_PING, mask_value=1, data="a"
)
self.assertEqual(abnf_str_data._get_masked(bytes_val), b"aaaa\x00")
def test_format(self):
abnf_bad_rsv_bits = ABNF(2, 0, 0, 0, opcode=ABNF.OPCODE_TEXT)
self.assertRaises(ValueError, abnf_bad_rsv_bits.format)
abnf_bad_opcode = ABNF(0, 0, 0, 0, opcode=5)
self.assertRaises(ValueError, abnf_bad_opcode.format)
abnf_length_10 = ABNF(0, 0, 0, 0, opcode=ABNF.OPCODE_TEXT, data="abcdefghij")
self.assertEqual(b"\x01", abnf_length_10.format()[0].to_bytes(1, "big"))
self.assertEqual(b"\x8a", abnf_length_10.format()[1].to_bytes(1, "big"))
self.assertEqual("fin=0 opcode=1 data=abcdefghij", abnf_length_10.__str__())
abnf_length_20 = ABNF(
0, 0, 0, 0, opcode=ABNF.OPCODE_BINARY, data="abcdefghijabcdefghij"
)
self.assertEqual(b"\x02", abnf_length_20.format()[0].to_bytes(1, "big"))
self.assertEqual(b"\x94", abnf_length_20.format()[1].to_bytes(1, "big"))
abnf_no_mask = ABNF(
0, 0, 0, 0, opcode=ABNF.OPCODE_TEXT, mask_value=0, data=b"\x01\x8a\xcc"
)
self.assertEqual(b"\x01\x03\x01\x8a\xcc", abnf_no_mask.format())
def test_frame_buffer(self):
fb = frame_buffer(0, True)
self.assertEqual(fb.recv, 0)
self.assertEqual(fb.skip_utf8_validation, True)
fb.clear
self.assertEqual(fb.header, None)
self.assertEqual(fb.length, None)
self.assertEqual(fb.mask_value, None)
self.assertEqual(fb.has_mask(), False)
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,352 @@
# -*- coding: utf-8 -*-
#
import os
import os.path
import ssl
import threading
import unittest
import websocket as ws
"""
test_app.py
websocket - WebSocket client library for Python
Copyright 2024 engn33r
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
# Skip test to access the internet unless TEST_WITH_INTERNET == 1
TEST_WITH_INTERNET = os.environ.get("TEST_WITH_INTERNET", "0") == "1"
# Skip tests relying on local websockets server unless LOCAL_WS_SERVER_PORT != -1
LOCAL_WS_SERVER_PORT = os.environ.get("LOCAL_WS_SERVER_PORT", "-1")
TEST_WITH_LOCAL_SERVER = LOCAL_WS_SERVER_PORT != "-1"
TRACEABLE = True
class WebSocketAppTest(unittest.TestCase):
class NotSetYet:
"""A marker class for signalling that a value hasn't been set yet."""
def setUp(self):
ws.enableTrace(TRACEABLE)
WebSocketAppTest.keep_running_open = WebSocketAppTest.NotSetYet()
WebSocketAppTest.keep_running_close = WebSocketAppTest.NotSetYet()
WebSocketAppTest.get_mask_key_id = WebSocketAppTest.NotSetYet()
WebSocketAppTest.on_error_data = WebSocketAppTest.NotSetYet()
def tearDown(self):
WebSocketAppTest.keep_running_open = WebSocketAppTest.NotSetYet()
WebSocketAppTest.keep_running_close = WebSocketAppTest.NotSetYet()
WebSocketAppTest.get_mask_key_id = WebSocketAppTest.NotSetYet()
WebSocketAppTest.on_error_data = WebSocketAppTest.NotSetYet()
def close(self):
pass
@unittest.skipUnless(
TEST_WITH_LOCAL_SERVER, "Tests using local websocket server are disabled"
)
def test_keep_running(self):
"""A WebSocketApp should keep running as long as its self.keep_running
is not False (in the boolean context).
"""
def on_open(self, *args, **kwargs):
"""Set the keep_running flag for later inspection and immediately
close the connection.
"""
self.send("hello!")
WebSocketAppTest.keep_running_open = self.keep_running
self.keep_running = False
def on_message(_, message):
print(message)
self.close()
def on_close(self, *args, **kwargs):
"""Set the keep_running flag for the test to use."""
WebSocketAppTest.keep_running_close = self.keep_running
app = ws.WebSocketApp(
f"ws://127.0.0.1:{LOCAL_WS_SERVER_PORT}",
on_open=on_open,
on_close=on_close,
on_message=on_message,
)
app.run_forever()
# @unittest.skipUnless(TEST_WITH_LOCAL_SERVER, "Tests using local websocket server are disabled")
@unittest.skipUnless(False, "Test disabled for now (requires rel)")
def test_run_forever_dispatcher(self):
"""A WebSocketApp should keep running as long as its self.keep_running
is not False (in the boolean context).
"""
def on_open(self, *args, **kwargs):
"""Send a message, receive, and send one more"""
self.send("hello!")
self.recv()
self.send("goodbye!")
def on_message(_, message):
print(message)
self.close()
app = ws.WebSocketApp(
f"ws://127.0.0.1:{LOCAL_WS_SERVER_PORT}",
on_open=on_open,
on_message=on_message,
)
app.run_forever(dispatcher="Dispatcher") # doesn't work
# app.run_forever(dispatcher=rel) # would work
# rel.dispatch()
@unittest.skipUnless(
TEST_WITH_LOCAL_SERVER, "Tests using local websocket server are disabled"
)
def test_run_forever_teardown_clean_exit(self):
"""The WebSocketApp.run_forever() method should return `False` when the application ends gracefully."""
app = ws.WebSocketApp(f"ws://127.0.0.1:{LOCAL_WS_SERVER_PORT}")
threading.Timer(interval=0.2, function=app.close).start()
teardown = app.run_forever()
self.assertEqual(teardown, False)
@unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
def test_sock_mask_key(self):
"""A WebSocketApp should forward the received mask_key function down
to the actual socket.
"""
def my_mask_key_func():
return "\x00\x00\x00\x00"
app = ws.WebSocketApp(
"wss://api-pub.bitfinex.com/ws/1", get_mask_key=my_mask_key_func
)
# if numpy is installed, this assertion fail
# Note: We can't use 'is' for comparing the functions directly, need to use 'id'.
self.assertEqual(id(app.get_mask_key), id(my_mask_key_func))
@unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
def test_invalid_ping_interval_ping_timeout(self):
"""Test exception handling if ping_interval < ping_timeout"""
def on_ping(app, _):
print("Got a ping!")
app.close()
def on_pong(app, _):
print("Got a pong! No need to respond")
app.close()
app = ws.WebSocketApp(
"wss://api-pub.bitfinex.com/ws/1", on_ping=on_ping, on_pong=on_pong
)
self.assertRaises(
ws.WebSocketException,
app.run_forever,
ping_interval=1,
ping_timeout=2,
sslopt={"cert_reqs": ssl.CERT_NONE},
)
@unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
def test_ping_interval(self):
"""Test WebSocketApp proper ping functionality"""
def on_ping(app, _):
print("Got a ping!")
app.close()
def on_pong(app, _):
print("Got a pong! No need to respond")
app.close()
app = ws.WebSocketApp(
"wss://api-pub.bitfinex.com/ws/1", on_ping=on_ping, on_pong=on_pong
)
app.run_forever(
ping_interval=2, ping_timeout=1, sslopt={"cert_reqs": ssl.CERT_NONE}
)
@unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
def test_opcode_close(self):
"""Test WebSocketApp close opcode"""
app = ws.WebSocketApp("wss://tsock.us1.twilio.com/v3/wsconnect")
app.run_forever(ping_interval=2, ping_timeout=1, ping_payload="Ping payload")
# This is commented out because the URL no longer responds in the expected way
# @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
# def testOpcodeBinary(self):
# """ Test WebSocketApp binary opcode
# """
# app = ws.WebSocketApp('wss://streaming.vn.teslamotors.com/streaming/')
# app.run_forever(ping_interval=2, ping_timeout=1, ping_payload="Ping payload")
@unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
def test_bad_ping_interval(self):
"""A WebSocketApp handling of negative ping_interval"""
app = ws.WebSocketApp("wss://api-pub.bitfinex.com/ws/1")
self.assertRaises(
ws.WebSocketException,
app.run_forever,
ping_interval=-5,
sslopt={"cert_reqs": ssl.CERT_NONE},
)
@unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
def test_bad_ping_timeout(self):
"""A WebSocketApp handling of negative ping_timeout"""
app = ws.WebSocketApp("wss://api-pub.bitfinex.com/ws/1")
self.assertRaises(
ws.WebSocketException,
app.run_forever,
ping_timeout=-3,
sslopt={"cert_reqs": ssl.CERT_NONE},
)
@unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
def test_close_status_code(self):
"""Test extraction of close frame status code and close reason in WebSocketApp"""
def on_close(wsapp, close_status_code, close_msg):
print("on_close reached")
app = ws.WebSocketApp(
"wss://tsock.us1.twilio.com/v3/wsconnect", on_close=on_close
)
closeframe = ws.ABNF(
opcode=ws.ABNF.OPCODE_CLOSE, data=b"\x03\xe8no-init-from-client"
)
self.assertEqual([1000, "no-init-from-client"], app._get_close_args(closeframe))
closeframe = ws.ABNF(opcode=ws.ABNF.OPCODE_CLOSE, data=b"")
self.assertEqual([None, None], app._get_close_args(closeframe))
app2 = ws.WebSocketApp("wss://tsock.us1.twilio.com/v3/wsconnect")
closeframe = ws.ABNF(opcode=ws.ABNF.OPCODE_CLOSE, data=b"")
self.assertEqual([None, None], app2._get_close_args(closeframe))
self.assertRaises(
ws.WebSocketConnectionClosedException,
app.send,
data="test if connection is closed",
)
@unittest.skipUnless(
TEST_WITH_LOCAL_SERVER, "Tests using local websocket server are disabled"
)
def test_callback_function_exception(self):
"""Test callback function exception handling"""
exc = None
passed_app = None
def on_open(app):
raise RuntimeError("Callback failed")
def on_error(app, err):
nonlocal passed_app
passed_app = app
nonlocal exc
exc = err
def on_pong(app, _):
app.close()
app = ws.WebSocketApp(
f"ws://127.0.0.1:{LOCAL_WS_SERVER_PORT}",
on_open=on_open,
on_error=on_error,
on_pong=on_pong,
)
app.run_forever(ping_interval=2, ping_timeout=1)
self.assertEqual(passed_app, app)
self.assertIsInstance(exc, RuntimeError)
self.assertEqual(str(exc), "Callback failed")
@unittest.skipUnless(
TEST_WITH_LOCAL_SERVER, "Tests using local websocket server are disabled"
)
def test_callback_method_exception(self):
"""Test callback method exception handling"""
class Callbacks:
def __init__(self):
self.exc = None
self.passed_app = None
self.app = ws.WebSocketApp(
f"ws://127.0.0.1:{LOCAL_WS_SERVER_PORT}",
on_open=self.on_open,
on_error=self.on_error,
on_pong=self.on_pong,
)
self.app.run_forever(ping_interval=2, ping_timeout=1)
def on_open(self, _):
raise RuntimeError("Callback failed")
def on_error(self, app, err):
self.passed_app = app
self.exc = err
def on_pong(self, app, _):
app.close()
callbacks = Callbacks()
self.assertEqual(callbacks.passed_app, callbacks.app)
self.assertIsInstance(callbacks.exc, RuntimeError)
self.assertEqual(str(callbacks.exc), "Callback failed")
@unittest.skipUnless(
TEST_WITH_LOCAL_SERVER, "Tests using local websocket server are disabled"
)
def test_reconnect(self):
"""Test reconnect"""
pong_count = 0
exc = None
def on_error(_, err):
nonlocal exc
exc = err
def on_pong(app, _):
nonlocal pong_count
pong_count += 1
if pong_count == 1:
# First pong, shutdown socket, enforce read error
app.sock.shutdown()
if pong_count >= 2:
# Got second pong after reconnect
app.close()
app = ws.WebSocketApp(
f"ws://127.0.0.1:{LOCAL_WS_SERVER_PORT}", on_pong=on_pong, on_error=on_error
)
app.run_forever(ping_interval=2, ping_timeout=1, reconnect=3)
self.assertEqual(pong_count, 2)
self.assertIsInstance(exc, ws.WebSocketTimeoutException)
self.assertEqual(str(exc), "ping/pong timed out")
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,123 @@
import unittest
from websocket._cookiejar import SimpleCookieJar
"""
test_cookiejar.py
websocket - WebSocket client library for Python
Copyright 2024 engn33r
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
class CookieJarTest(unittest.TestCase):
def test_add(self):
cookie_jar = SimpleCookieJar()
cookie_jar.add("")
self.assertFalse(
cookie_jar.jar, "Cookie with no domain should not be added to the jar"
)
cookie_jar = SimpleCookieJar()
cookie_jar.add("a=b")
self.assertFalse(
cookie_jar.jar, "Cookie with no domain should not be added to the jar"
)
cookie_jar = SimpleCookieJar()
cookie_jar.add("a=b; domain=.abc")
self.assertTrue(".abc" in cookie_jar.jar)
cookie_jar = SimpleCookieJar()
cookie_jar.add("a=b; domain=abc")
self.assertTrue(".abc" in cookie_jar.jar)
self.assertTrue("abc" not in cookie_jar.jar)
cookie_jar = SimpleCookieJar()
cookie_jar.add("a=b; c=d; domain=abc")
self.assertEqual(cookie_jar.get("abc"), "a=b; c=d")
self.assertEqual(cookie_jar.get(None), "")
cookie_jar = SimpleCookieJar()
cookie_jar.add("a=b; c=d; domain=abc")
cookie_jar.add("e=f; domain=abc")
self.assertEqual(cookie_jar.get("abc"), "a=b; c=d; e=f")
cookie_jar = SimpleCookieJar()
cookie_jar.add("a=b; c=d; domain=abc")
cookie_jar.add("e=f; domain=.abc")
self.assertEqual(cookie_jar.get("abc"), "a=b; c=d; e=f")
cookie_jar = SimpleCookieJar()
cookie_jar.add("a=b; c=d; domain=abc")
cookie_jar.add("e=f; domain=xyz")
self.assertEqual(cookie_jar.get("abc"), "a=b; c=d")
self.assertEqual(cookie_jar.get("xyz"), "e=f")
self.assertEqual(cookie_jar.get("something"), "")
def test_set(self):
cookie_jar = SimpleCookieJar()
cookie_jar.set("a=b")
self.assertFalse(
cookie_jar.jar, "Cookie with no domain should not be added to the jar"
)
cookie_jar = SimpleCookieJar()
cookie_jar.set("a=b; domain=.abc")
self.assertTrue(".abc" in cookie_jar.jar)
cookie_jar = SimpleCookieJar()
cookie_jar.set("a=b; domain=abc")
self.assertTrue(".abc" in cookie_jar.jar)
self.assertTrue("abc" not in cookie_jar.jar)
cookie_jar = SimpleCookieJar()
cookie_jar.set("a=b; c=d; domain=abc")
self.assertEqual(cookie_jar.get("abc"), "a=b; c=d")
cookie_jar = SimpleCookieJar()
cookie_jar.set("a=b; c=d; domain=abc")
cookie_jar.set("e=f; domain=abc")
self.assertEqual(cookie_jar.get("abc"), "e=f")
cookie_jar = SimpleCookieJar()
cookie_jar.set("a=b; c=d; domain=abc")
cookie_jar.set("e=f; domain=.abc")
self.assertEqual(cookie_jar.get("abc"), "e=f")
cookie_jar = SimpleCookieJar()
cookie_jar.set("a=b; c=d; domain=abc")
cookie_jar.set("e=f; domain=xyz")
self.assertEqual(cookie_jar.get("abc"), "a=b; c=d")
self.assertEqual(cookie_jar.get("xyz"), "e=f")
self.assertEqual(cookie_jar.get("something"), "")
def test_get(self):
cookie_jar = SimpleCookieJar()
cookie_jar.set("a=b; c=d; domain=abc.com")
self.assertEqual(cookie_jar.get("abc.com"), "a=b; c=d")
self.assertEqual(cookie_jar.get("x.abc.com"), "a=b; c=d")
self.assertEqual(cookie_jar.get("abc.com.es"), "")
self.assertEqual(cookie_jar.get("xabc.com"), "")
cookie_jar.set("a=b; c=d; domain=.abc.com")
self.assertEqual(cookie_jar.get("abc.com"), "a=b; c=d")
self.assertEqual(cookie_jar.get("x.abc.com"), "a=b; c=d")
self.assertEqual(cookie_jar.get("abc.com.es"), "")
self.assertEqual(cookie_jar.get("xabc.com"), "")
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,370 @@
# -*- coding: utf-8 -*-
#
import os
import os.path
import socket
import ssl
import unittest
import websocket
from websocket._exceptions import WebSocketProxyException, WebSocketException
from websocket._http import (
_get_addrinfo_list,
_start_proxied_socket,
_tunnel,
connect,
proxy_info,
read_headers,
HAVE_PYTHON_SOCKS,
)
"""
test_http.py
websocket - WebSocket client library for Python
Copyright 2024 engn33r
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
try:
from python_socks._errors import ProxyConnectionError, ProxyError, ProxyTimeoutError
except:
from websocket._http import ProxyConnectionError, ProxyError, ProxyTimeoutError
# Skip test to access the internet unless TEST_WITH_INTERNET == 1
TEST_WITH_INTERNET = os.environ.get("TEST_WITH_INTERNET", "0") == "1"
TEST_WITH_PROXY = os.environ.get("TEST_WITH_PROXY", "0") == "1"
# Skip tests relying on local websockets server unless LOCAL_WS_SERVER_PORT != -1
LOCAL_WS_SERVER_PORT = os.environ.get("LOCAL_WS_SERVER_PORT", "-1")
TEST_WITH_LOCAL_SERVER = LOCAL_WS_SERVER_PORT != "-1"
class SockMock:
def __init__(self):
self.data = []
self.sent = []
def add_packet(self, data):
self.data.append(data)
def gettimeout(self):
return None
def recv(self, bufsize):
if self.data:
e = self.data.pop(0)
if isinstance(e, Exception):
raise e
if len(e) > bufsize:
self.data.insert(0, e[bufsize:])
return e[:bufsize]
def send(self, data):
self.sent.append(data)
return len(data)
def close(self):
pass
class HeaderSockMock(SockMock):
def __init__(self, fname):
SockMock.__init__(self)
path = os.path.join(os.path.dirname(__file__), fname)
with open(path, "rb") as f:
self.add_packet(f.read())
class OptsList:
def __init__(self):
self.timeout = 1
self.sockopt = []
self.sslopt = {"cert_reqs": ssl.CERT_NONE}
class HttpTest(unittest.TestCase):
def test_read_header(self):
status, header, _ = read_headers(HeaderSockMock("data/header01.txt"))
self.assertEqual(status, 101)
self.assertEqual(header["connection"], "Upgrade")
# header02.txt is intentionally malformed
self.assertRaises(
WebSocketException, read_headers, HeaderSockMock("data/header02.txt")
)
def test_tunnel(self):
self.assertRaises(
WebSocketProxyException,
_tunnel,
HeaderSockMock("data/header01.txt"),
"example.com",
80,
("username", "password"),
)
self.assertRaises(
WebSocketProxyException,
_tunnel,
HeaderSockMock("data/header02.txt"),
"example.com",
80,
("username", "password"),
)
@unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
def test_connect(self):
# Not currently testing an actual proxy connection, so just check whether proxy errors are raised. This requires internet for a DNS lookup
if HAVE_PYTHON_SOCKS:
# Need this check, otherwise case where python_socks is not installed triggers
# websocket._exceptions.WebSocketException: Python Socks is needed for SOCKS proxying but is not available
self.assertRaises(
(ProxyTimeoutError, OSError),
_start_proxied_socket,
"wss://example.com",
OptsList(),
proxy_info(
http_proxy_host="example.com",
http_proxy_port="8080",
proxy_type="socks4",
http_proxy_timeout=1,
),
)
self.assertRaises(
(ProxyTimeoutError, OSError),
_start_proxied_socket,
"wss://example.com",
OptsList(),
proxy_info(
http_proxy_host="example.com",
http_proxy_port="8080",
proxy_type="socks4a",
http_proxy_timeout=1,
),
)
self.assertRaises(
(ProxyTimeoutError, OSError),
_start_proxied_socket,
"wss://example.com",
OptsList(),
proxy_info(
http_proxy_host="example.com",
http_proxy_port="8080",
proxy_type="socks5",
http_proxy_timeout=1,
),
)
self.assertRaises(
(ProxyTimeoutError, OSError),
_start_proxied_socket,
"wss://example.com",
OptsList(),
proxy_info(
http_proxy_host="example.com",
http_proxy_port="8080",
proxy_type="socks5h",
http_proxy_timeout=1,
),
)
self.assertRaises(
ProxyConnectionError,
connect,
"wss://example.com",
OptsList(),
proxy_info(
http_proxy_host="127.0.0.1",
http_proxy_port=9999,
proxy_type="socks4",
http_proxy_timeout=1,
),
None,
)
self.assertRaises(
TypeError,
_get_addrinfo_list,
None,
80,
True,
proxy_info(
http_proxy_host="127.0.0.1", http_proxy_port="9999", proxy_type="http"
),
)
self.assertRaises(
TypeError,
_get_addrinfo_list,
None,
80,
True,
proxy_info(
http_proxy_host="127.0.0.1", http_proxy_port="9999", proxy_type="http"
),
)
self.assertRaises(
socket.timeout,
connect,
"wss://google.com",
OptsList(),
proxy_info(
http_proxy_host="8.8.8.8",
http_proxy_port=9999,
proxy_type="http",
http_proxy_timeout=1,
),
None,
)
self.assertEqual(
connect(
"wss://google.com",
OptsList(),
proxy_info(
http_proxy_host="8.8.8.8", http_proxy_port=8080, proxy_type="http"
),
True,
),
(True, ("google.com", 443, "/")),
)
# The following test fails on Mac OS with a gaierror, not an OverflowError
# self.assertRaises(OverflowError, connect, "wss://example.com", OptsList(), proxy_info(http_proxy_host="127.0.0.1", http_proxy_port=99999, proxy_type="socks4", timeout=2), False)
@unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
@unittest.skipUnless(
TEST_WITH_PROXY, "This test requires a HTTP proxy to be running on port 8899"
)
@unittest.skipUnless(
TEST_WITH_LOCAL_SERVER, "Tests using local websocket server are disabled"
)
def test_proxy_connect(self):
ws = websocket.WebSocket()
ws.connect(
f"ws://127.0.0.1:{LOCAL_WS_SERVER_PORT}",
http_proxy_host="127.0.0.1",
http_proxy_port="8899",
proxy_type="http",
)
ws.send("Hello, Server")
server_response = ws.recv()
self.assertEqual(server_response, "Hello, Server")
# self.assertEqual(_start_proxied_socket("wss://api.bitfinex.com/ws/2", OptsList(), proxy_info(http_proxy_host="127.0.0.1", http_proxy_port="8899", proxy_type="http"))[1], ("api.bitfinex.com", 443, '/ws/2'))
self.assertEqual(
_get_addrinfo_list(
"api.bitfinex.com",
443,
True,
proxy_info(
http_proxy_host="127.0.0.1",
http_proxy_port="8899",
proxy_type="http",
),
),
(
socket.getaddrinfo(
"127.0.0.1", 8899, 0, socket.SOCK_STREAM, socket.SOL_TCP
),
True,
None,
),
)
self.assertEqual(
connect(
"wss://api.bitfinex.com/ws/2",
OptsList(),
proxy_info(
http_proxy_host="127.0.0.1", http_proxy_port=8899, proxy_type="http"
),
None,
)[1],
("api.bitfinex.com", 443, "/ws/2"),
)
# TODO: Test SOCKS4 and SOCK5 proxies with unit tests
@unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
def test_sslopt(self):
ssloptions = {
"check_hostname": False,
"server_hostname": "ServerName",
"ssl_version": ssl.PROTOCOL_TLS_CLIENT,
"ciphers": "TLS_AES_256_GCM_SHA384:TLS_CHACHA20_POLY1305_SHA256:\
TLS_AES_128_GCM_SHA256:ECDHE-ECDSA-AES256-GCM-SHA384:\
ECDHE-RSA-AES256-GCM-SHA384:DHE-RSA-AES256-GCM-SHA384:\
ECDHE-ECDSA-CHACHA20-POLY1305:ECDHE-RSA-CHACHA20-POLY1305:\
DHE-RSA-CHACHA20-POLY1305:ECDHE-ECDSA-AES128-GCM-SHA256:\
ECDHE-RSA-AES128-GCM-SHA256:DHE-RSA-AES128-GCM-SHA256:\
ECDHE-ECDSA-AES256-SHA384:ECDHE-RSA-AES256-SHA384:\
DHE-RSA-AES256-SHA256:ECDHE-ECDSA-AES128-SHA256:\
ECDHE-RSA-AES128-SHA256:DHE-RSA-AES128-SHA256:\
ECDHE-ECDSA-AES256-SHA:ECDHE-RSA-AES256-SHA",
"ecdh_curve": "prime256v1",
}
ws_ssl1 = websocket.WebSocket(sslopt=ssloptions)
ws_ssl1.connect("wss://api.bitfinex.com/ws/2")
ws_ssl1.send("Hello")
ws_ssl1.close()
ws_ssl2 = websocket.WebSocket(sslopt={"check_hostname": True})
ws_ssl2.connect("wss://api.bitfinex.com/ws/2")
ws_ssl2.close
def test_proxy_info(self):
self.assertEqual(
proxy_info(
http_proxy_host="127.0.0.1", http_proxy_port="8080", proxy_type="http"
).proxy_protocol,
"http",
)
self.assertRaises(
ProxyError,
proxy_info,
http_proxy_host="127.0.0.1",
http_proxy_port="8080",
proxy_type="badval",
)
self.assertEqual(
proxy_info(
http_proxy_host="example.com", http_proxy_port="8080", proxy_type="http"
).proxy_host,
"example.com",
)
self.assertEqual(
proxy_info(
http_proxy_host="127.0.0.1", http_proxy_port="8080", proxy_type="http"
).proxy_port,
"8080",
)
self.assertEqual(
proxy_info(
http_proxy_host="127.0.0.1", http_proxy_port="8080", proxy_type="http"
).auth,
None,
)
self.assertEqual(
proxy_info(
http_proxy_host="127.0.0.1",
http_proxy_port="8080",
proxy_type="http",
http_proxy_auth=("my_username123", "my_pass321"),
).auth[0],
"my_username123",
)
self.assertEqual(
proxy_info(
http_proxy_host="127.0.0.1",
http_proxy_port="8080",
proxy_type="http",
http_proxy_auth=("my_username123", "my_pass321"),
).auth[1],
"my_pass321",
)
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,464 @@
# -*- coding: utf-8 -*-
#
import os
import unittest
from websocket._url import (
_is_address_in_network,
_is_no_proxy_host,
get_proxy_info,
parse_url,
)
from websocket._exceptions import WebSocketProxyException
"""
test_url.py
websocket - WebSocket client library for Python
Copyright 2024 engn33r
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
class UrlTest(unittest.TestCase):
def test_address_in_network(self):
self.assertTrue(_is_address_in_network("127.0.0.1", "127.0.0.0/8"))
self.assertTrue(_is_address_in_network("127.1.0.1", "127.0.0.0/8"))
self.assertFalse(_is_address_in_network("127.1.0.1", "127.0.0.0/24"))
def test_parse_url(self):
p = parse_url("ws://www.example.com/r")
self.assertEqual(p[0], "www.example.com")
self.assertEqual(p[1], 80)
self.assertEqual(p[2], "/r")
self.assertEqual(p[3], False)
p = parse_url("ws://www.example.com/r/")
self.assertEqual(p[0], "www.example.com")
self.assertEqual(p[1], 80)
self.assertEqual(p[2], "/r/")
self.assertEqual(p[3], False)
p = parse_url("ws://www.example.com/")
self.assertEqual(p[0], "www.example.com")
self.assertEqual(p[1], 80)
self.assertEqual(p[2], "/")
self.assertEqual(p[3], False)
p = parse_url("ws://www.example.com")
self.assertEqual(p[0], "www.example.com")
self.assertEqual(p[1], 80)
self.assertEqual(p[2], "/")
self.assertEqual(p[3], False)
p = parse_url("ws://www.example.com:8080/r")
self.assertEqual(p[0], "www.example.com")
self.assertEqual(p[1], 8080)
self.assertEqual(p[2], "/r")
self.assertEqual(p[3], False)
p = parse_url("ws://www.example.com:8080/")
self.assertEqual(p[0], "www.example.com")
self.assertEqual(p[1], 8080)
self.assertEqual(p[2], "/")
self.assertEqual(p[3], False)
p = parse_url("ws://www.example.com:8080")
self.assertEqual(p[0], "www.example.com")
self.assertEqual(p[1], 8080)
self.assertEqual(p[2], "/")
self.assertEqual(p[3], False)
p = parse_url("wss://www.example.com:8080/r")
self.assertEqual(p[0], "www.example.com")
self.assertEqual(p[1], 8080)
self.assertEqual(p[2], "/r")
self.assertEqual(p[3], True)
p = parse_url("wss://www.example.com:8080/r?key=value")
self.assertEqual(p[0], "www.example.com")
self.assertEqual(p[1], 8080)
self.assertEqual(p[2], "/r?key=value")
self.assertEqual(p[3], True)
self.assertRaises(ValueError, parse_url, "http://www.example.com/r")
p = parse_url("ws://[2a03:4000:123:83::3]/r")
self.assertEqual(p[0], "2a03:4000:123:83::3")
self.assertEqual(p[1], 80)
self.assertEqual(p[2], "/r")
self.assertEqual(p[3], False)
p = parse_url("ws://[2a03:4000:123:83::3]:8080/r")
self.assertEqual(p[0], "2a03:4000:123:83::3")
self.assertEqual(p[1], 8080)
self.assertEqual(p[2], "/r")
self.assertEqual(p[3], False)
p = parse_url("wss://[2a03:4000:123:83::3]/r")
self.assertEqual(p[0], "2a03:4000:123:83::3")
self.assertEqual(p[1], 443)
self.assertEqual(p[2], "/r")
self.assertEqual(p[3], True)
p = parse_url("wss://[2a03:4000:123:83::3]:8080/r")
self.assertEqual(p[0], "2a03:4000:123:83::3")
self.assertEqual(p[1], 8080)
self.assertEqual(p[2], "/r")
self.assertEqual(p[3], True)
class IsNoProxyHostTest(unittest.TestCase):
def setUp(self):
self.no_proxy = os.environ.get("no_proxy", None)
if "no_proxy" in os.environ:
del os.environ["no_proxy"]
def tearDown(self):
if self.no_proxy:
os.environ["no_proxy"] = self.no_proxy
elif "no_proxy" in os.environ:
del os.environ["no_proxy"]
def test_match_all(self):
self.assertTrue(_is_no_proxy_host("any.websocket.org", ["*"]))
self.assertTrue(_is_no_proxy_host("192.168.0.1", ["*"]))
self.assertFalse(_is_no_proxy_host("192.168.0.1", ["192.168.1.1"]))
self.assertFalse(
_is_no_proxy_host("any.websocket.org", ["other.websocket.org"])
)
self.assertTrue(
_is_no_proxy_host("any.websocket.org", ["other.websocket.org", "*"])
)
os.environ["no_proxy"] = "*"
self.assertTrue(_is_no_proxy_host("any.websocket.org", None))
self.assertTrue(_is_no_proxy_host("192.168.0.1", None))
os.environ["no_proxy"] = "other.websocket.org, *"
self.assertTrue(_is_no_proxy_host("any.websocket.org", None))
def test_ip_address(self):
self.assertTrue(_is_no_proxy_host("127.0.0.1", ["127.0.0.1"]))
self.assertFalse(_is_no_proxy_host("127.0.0.2", ["127.0.0.1"]))
self.assertTrue(
_is_no_proxy_host("127.0.0.1", ["other.websocket.org", "127.0.0.1"])
)
self.assertFalse(
_is_no_proxy_host("127.0.0.2", ["other.websocket.org", "127.0.0.1"])
)
os.environ["no_proxy"] = "127.0.0.1"
self.assertTrue(_is_no_proxy_host("127.0.0.1", None))
self.assertFalse(_is_no_proxy_host("127.0.0.2", None))
os.environ["no_proxy"] = "other.websocket.org, 127.0.0.1"
self.assertTrue(_is_no_proxy_host("127.0.0.1", None))
self.assertFalse(_is_no_proxy_host("127.0.0.2", None))
def test_ip_address_in_range(self):
self.assertTrue(_is_no_proxy_host("127.0.0.1", ["127.0.0.0/8"]))
self.assertTrue(_is_no_proxy_host("127.0.0.2", ["127.0.0.0/8"]))
self.assertFalse(_is_no_proxy_host("127.1.0.1", ["127.0.0.0/24"]))
os.environ["no_proxy"] = "127.0.0.0/8"
self.assertTrue(_is_no_proxy_host("127.0.0.1", None))
self.assertTrue(_is_no_proxy_host("127.0.0.2", None))
os.environ["no_proxy"] = "127.0.0.0/24"
self.assertFalse(_is_no_proxy_host("127.1.0.1", None))
def test_hostname_match(self):
self.assertTrue(_is_no_proxy_host("my.websocket.org", ["my.websocket.org"]))
self.assertTrue(
_is_no_proxy_host(
"my.websocket.org", ["other.websocket.org", "my.websocket.org"]
)
)
self.assertFalse(_is_no_proxy_host("my.websocket.org", ["other.websocket.org"]))
os.environ["no_proxy"] = "my.websocket.org"
self.assertTrue(_is_no_proxy_host("my.websocket.org", None))
self.assertFalse(_is_no_proxy_host("other.websocket.org", None))
os.environ["no_proxy"] = "other.websocket.org, my.websocket.org"
self.assertTrue(_is_no_proxy_host("my.websocket.org", None))
def test_hostname_match_domain(self):
self.assertTrue(_is_no_proxy_host("any.websocket.org", [".websocket.org"]))
self.assertTrue(_is_no_proxy_host("my.other.websocket.org", [".websocket.org"]))
self.assertTrue(
_is_no_proxy_host(
"any.websocket.org", ["my.websocket.org", ".websocket.org"]
)
)
self.assertFalse(_is_no_proxy_host("any.websocket.com", [".websocket.org"]))
os.environ["no_proxy"] = ".websocket.org"
self.assertTrue(_is_no_proxy_host("any.websocket.org", None))
self.assertTrue(_is_no_proxy_host("my.other.websocket.org", None))
self.assertFalse(_is_no_proxy_host("any.websocket.com", None))
os.environ["no_proxy"] = "my.websocket.org, .websocket.org"
self.assertTrue(_is_no_proxy_host("any.websocket.org", None))
class ProxyInfoTest(unittest.TestCase):
def setUp(self):
self.http_proxy = os.environ.get("http_proxy", None)
self.https_proxy = os.environ.get("https_proxy", None)
self.no_proxy = os.environ.get("no_proxy", None)
if "http_proxy" in os.environ:
del os.environ["http_proxy"]
if "https_proxy" in os.environ:
del os.environ["https_proxy"]
if "no_proxy" in os.environ:
del os.environ["no_proxy"]
def tearDown(self):
if self.http_proxy:
os.environ["http_proxy"] = self.http_proxy
elif "http_proxy" in os.environ:
del os.environ["http_proxy"]
if self.https_proxy:
os.environ["https_proxy"] = self.https_proxy
elif "https_proxy" in os.environ:
del os.environ["https_proxy"]
if self.no_proxy:
os.environ["no_proxy"] = self.no_proxy
elif "no_proxy" in os.environ:
del os.environ["no_proxy"]
def test_proxy_from_args(self):
self.assertRaises(
WebSocketProxyException,
get_proxy_info,
"echo.websocket.events",
False,
proxy_host="localhost",
)
self.assertEqual(
get_proxy_info(
"echo.websocket.events", False, proxy_host="localhost", proxy_port=3128
),
("localhost", 3128, None),
)
self.assertEqual(
get_proxy_info(
"echo.websocket.events", True, proxy_host="localhost", proxy_port=3128
),
("localhost", 3128, None),
)
self.assertEqual(
get_proxy_info(
"echo.websocket.events",
False,
proxy_host="localhost",
proxy_port=9001,
proxy_auth=("a", "b"),
),
("localhost", 9001, ("a", "b")),
)
self.assertEqual(
get_proxy_info(
"echo.websocket.events",
False,
proxy_host="localhost",
proxy_port=3128,
proxy_auth=("a", "b"),
),
("localhost", 3128, ("a", "b")),
)
self.assertEqual(
get_proxy_info(
"echo.websocket.events",
True,
proxy_host="localhost",
proxy_port=8765,
proxy_auth=("a", "b"),
),
("localhost", 8765, ("a", "b")),
)
self.assertEqual(
get_proxy_info(
"echo.websocket.events",
True,
proxy_host="localhost",
proxy_port=3128,
proxy_auth=("a", "b"),
),
("localhost", 3128, ("a", "b")),
)
self.assertEqual(
get_proxy_info(
"echo.websocket.events",
True,
proxy_host="localhost",
proxy_port=3128,
no_proxy=["example.com"],
proxy_auth=("a", "b"),
),
("localhost", 3128, ("a", "b")),
)
self.assertEqual(
get_proxy_info(
"echo.websocket.events",
True,
proxy_host="localhost",
proxy_port=3128,
no_proxy=["echo.websocket.events"],
proxy_auth=("a", "b"),
),
(None, 0, None),
)
self.assertEqual(
get_proxy_info(
"echo.websocket.events",
True,
proxy_host="localhost",
proxy_port=3128,
no_proxy=[".websocket.events"],
),
(None, 0, None),
)
def test_proxy_from_env(self):
os.environ["http_proxy"] = "http://localhost/"
self.assertEqual(
get_proxy_info("echo.websocket.events", False), ("localhost", None, None)
)
os.environ["http_proxy"] = "http://localhost:3128/"
self.assertEqual(
get_proxy_info("echo.websocket.events", False), ("localhost", 3128, None)
)
os.environ["http_proxy"] = "http://localhost/"
os.environ["https_proxy"] = "http://localhost2/"
self.assertEqual(
get_proxy_info("echo.websocket.events", False), ("localhost", None, None)
)
os.environ["http_proxy"] = "http://localhost:3128/"
os.environ["https_proxy"] = "http://localhost2:3128/"
self.assertEqual(
get_proxy_info("echo.websocket.events", False), ("localhost", 3128, None)
)
os.environ["http_proxy"] = "http://localhost/"
os.environ["https_proxy"] = "http://localhost2/"
self.assertEqual(
get_proxy_info("echo.websocket.events", True), ("localhost2", None, None)
)
os.environ["http_proxy"] = "http://localhost:3128/"
os.environ["https_proxy"] = "http://localhost2:3128/"
self.assertEqual(
get_proxy_info("echo.websocket.events", True), ("localhost2", 3128, None)
)
os.environ["http_proxy"] = ""
os.environ["https_proxy"] = "http://localhost2/"
self.assertEqual(
get_proxy_info("echo.websocket.events", True), ("localhost2", None, None)
)
self.assertEqual(
get_proxy_info("echo.websocket.events", False), (None, 0, None)
)
os.environ["http_proxy"] = ""
os.environ["https_proxy"] = "http://localhost2:3128/"
self.assertEqual(
get_proxy_info("echo.websocket.events", True), ("localhost2", 3128, None)
)
self.assertEqual(
get_proxy_info("echo.websocket.events", False), (None, 0, None)
)
os.environ["http_proxy"] = "http://localhost/"
os.environ["https_proxy"] = ""
self.assertEqual(get_proxy_info("echo.websocket.events", True), (None, 0, None))
self.assertEqual(
get_proxy_info("echo.websocket.events", False), ("localhost", None, None)
)
os.environ["http_proxy"] = "http://localhost:3128/"
os.environ["https_proxy"] = ""
self.assertEqual(get_proxy_info("echo.websocket.events", True), (None, 0, None))
self.assertEqual(
get_proxy_info("echo.websocket.events", False), ("localhost", 3128, None)
)
os.environ["http_proxy"] = "http://a:b@localhost/"
self.assertEqual(
get_proxy_info("echo.websocket.events", False),
("localhost", None, ("a", "b")),
)
os.environ["http_proxy"] = "http://a:b@localhost:3128/"
self.assertEqual(
get_proxy_info("echo.websocket.events", False),
("localhost", 3128, ("a", "b")),
)
os.environ["http_proxy"] = "http://a:b@localhost/"
os.environ["https_proxy"] = "http://a:b@localhost2/"
self.assertEqual(
get_proxy_info("echo.websocket.events", False),
("localhost", None, ("a", "b")),
)
os.environ["http_proxy"] = "http://a:b@localhost:3128/"
os.environ["https_proxy"] = "http://a:b@localhost2:3128/"
self.assertEqual(
get_proxy_info("echo.websocket.events", False),
("localhost", 3128, ("a", "b")),
)
os.environ["http_proxy"] = "http://a:b@localhost/"
os.environ["https_proxy"] = "http://a:b@localhost2/"
self.assertEqual(
get_proxy_info("echo.websocket.events", True),
("localhost2", None, ("a", "b")),
)
os.environ["http_proxy"] = "http://a:b@localhost:3128/"
os.environ["https_proxy"] = "http://a:b@localhost2:3128/"
self.assertEqual(
get_proxy_info("echo.websocket.events", True),
("localhost2", 3128, ("a", "b")),
)
os.environ["http_proxy"] = (
"http://john%40example.com:P%40SSWORD@localhost:3128/"
)
os.environ["https_proxy"] = (
"http://john%40example.com:P%40SSWORD@localhost2:3128/"
)
self.assertEqual(
get_proxy_info("echo.websocket.events", True),
("localhost2", 3128, ("john@example.com", "P@SSWORD")),
)
os.environ["http_proxy"] = "http://a:b@localhost/"
os.environ["https_proxy"] = "http://a:b@localhost2/"
os.environ["no_proxy"] = "example1.com,example2.com"
self.assertEqual(
get_proxy_info("example.1.com", True), ("localhost2", None, ("a", "b"))
)
os.environ["http_proxy"] = "http://a:b@localhost:3128/"
os.environ["https_proxy"] = "http://a:b@localhost2:3128/"
os.environ["no_proxy"] = "example1.com,example2.com, echo.websocket.events"
self.assertEqual(get_proxy_info("echo.websocket.events", True), (None, 0, None))
os.environ["http_proxy"] = "http://a:b@localhost:3128/"
os.environ["https_proxy"] = "http://a:b@localhost2:3128/"
os.environ["no_proxy"] = "example1.com,example2.com, .websocket.events"
self.assertEqual(get_proxy_info("echo.websocket.events", True), (None, 0, None))
os.environ["http_proxy"] = "http://a:b@localhost:3128/"
os.environ["https_proxy"] = "http://a:b@localhost2:3128/"
os.environ["no_proxy"] = "127.0.0.0/8, 192.168.0.0/16"
self.assertEqual(get_proxy_info("127.0.0.1", False), (None, 0, None))
self.assertEqual(get_proxy_info("192.168.1.1", False), (None, 0, None))
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,497 @@
# -*- coding: utf-8 -*-
#
import os
import os.path
import socket
import unittest
from base64 import decodebytes as base64decode
import websocket as ws
from websocket._exceptions import WebSocketBadStatusException, WebSocketAddressException
from websocket._handshake import _create_sec_websocket_key
from websocket._handshake import _validate as _validate_header
from websocket._http import read_headers
from websocket._utils import validate_utf8
"""
test_websocket.py
websocket - WebSocket client library for Python
Copyright 2024 engn33r
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
try:
import ssl
except ImportError:
# dummy class of SSLError for ssl none-support environment.
class SSLError(Exception):
pass
# Skip test to access the internet unless TEST_WITH_INTERNET == 1
TEST_WITH_INTERNET = os.environ.get("TEST_WITH_INTERNET", "0") == "1"
# Skip tests relying on local websockets server unless LOCAL_WS_SERVER_PORT != -1
LOCAL_WS_SERVER_PORT = os.environ.get("LOCAL_WS_SERVER_PORT", "-1")
TEST_WITH_LOCAL_SERVER = LOCAL_WS_SERVER_PORT != "-1"
TRACEABLE = True
def create_mask_key(_):
return "abcd"
class SockMock:
def __init__(self):
self.data = []
self.sent = []
def add_packet(self, data):
self.data.append(data)
def gettimeout(self):
return None
def recv(self, bufsize):
if self.data:
e = self.data.pop(0)
if isinstance(e, Exception):
raise e
if len(e) > bufsize:
self.data.insert(0, e[bufsize:])
return e[:bufsize]
def send(self, data):
self.sent.append(data)
return len(data)
def close(self):
pass
class HeaderSockMock(SockMock):
def __init__(self, fname):
SockMock.__init__(self)
path = os.path.join(os.path.dirname(__file__), fname)
with open(path, "rb") as f:
self.add_packet(f.read())
class WebSocketTest(unittest.TestCase):
def setUp(self):
ws.enableTrace(TRACEABLE)
def tearDown(self):
pass
def test_default_timeout(self):
self.assertEqual(ws.getdefaulttimeout(), None)
ws.setdefaulttimeout(10)
self.assertEqual(ws.getdefaulttimeout(), 10)
ws.setdefaulttimeout(None)
def test_ws_key(self):
key = _create_sec_websocket_key()
self.assertTrue(key != 24)
self.assertTrue("¥n" not in key)
def test_nonce(self):
"""WebSocket key should be a random 16-byte nonce."""
key = _create_sec_websocket_key()
nonce = base64decode(key.encode("utf-8"))
self.assertEqual(16, len(nonce))
def test_ws_utils(self):
key = "c6b8hTg4EeGb2gQMztV1/g=="
required_header = {
"upgrade": "websocket",
"connection": "upgrade",
"sec-websocket-accept": "Kxep+hNu9n51529fGidYu7a3wO0=",
}
self.assertEqual(_validate_header(required_header, key, None), (True, None))
header = required_header.copy()
header["upgrade"] = "http"
self.assertEqual(_validate_header(header, key, None), (False, None))
del header["upgrade"]
self.assertEqual(_validate_header(header, key, None), (False, None))
header = required_header.copy()
header["connection"] = "something"
self.assertEqual(_validate_header(header, key, None), (False, None))
del header["connection"]
self.assertEqual(_validate_header(header, key, None), (False, None))
header = required_header.copy()
header["sec-websocket-accept"] = "something"
self.assertEqual(_validate_header(header, key, None), (False, None))
del header["sec-websocket-accept"]
self.assertEqual(_validate_header(header, key, None), (False, None))
header = required_header.copy()
header["sec-websocket-protocol"] = "sub1"
self.assertEqual(
_validate_header(header, key, ["sub1", "sub2"]), (True, "sub1")
)
# This case will print out a logging error using the error() function, but that is expected
self.assertEqual(_validate_header(header, key, ["sub2", "sub3"]), (False, None))
header = required_header.copy()
header["sec-websocket-protocol"] = "sUb1"
self.assertEqual(
_validate_header(header, key, ["Sub1", "suB2"]), (True, "sub1")
)
header = required_header.copy()
# This case will print out a logging error using the error() function, but that is expected
self.assertEqual(_validate_header(header, key, ["Sub1", "suB2"]), (False, None))
def test_read_header(self):
status, header, _ = read_headers(HeaderSockMock("data/header01.txt"))
self.assertEqual(status, 101)
self.assertEqual(header["connection"], "Upgrade")
status, header, _ = read_headers(HeaderSockMock("data/header03.txt"))
self.assertEqual(status, 101)
self.assertEqual(header["connection"], "Upgrade, Keep-Alive")
HeaderSockMock("data/header02.txt")
self.assertRaises(
ws.WebSocketException, read_headers, HeaderSockMock("data/header02.txt")
)
def test_send(self):
# TODO: add longer frame data
sock = ws.WebSocket()
sock.set_mask_key(create_mask_key)
s = sock.sock = HeaderSockMock("data/header01.txt")
sock.send("Hello")
self.assertEqual(s.sent[0], b"\x81\x85abcd)\x07\x0f\x08\x0e")
sock.send("こんにちは")
self.assertEqual(
s.sent[1],
b"\x81\x8fabcd\x82\xe3\xf0\x87\xe3\xf1\x80\xe5\xca\x81\xe2\xc5\x82\xe3\xcc",
)
# sock.send("x" * 5000)
# self.assertEqual(s.sent[1], b'\x81\x8fabcd\x82\xe3\xf0\x87\xe3\xf1\x80\xe5\xca\x81\xe2\xc5\x82\xe3\xcc")
self.assertEqual(sock.send_binary(b"1111111111101"), 19)
def test_recv(self):
# TODO: add longer frame data
sock = ws.WebSocket()
s = sock.sock = SockMock()
something = (
b"\x81\x8fabcd\x82\xe3\xf0\x87\xe3\xf1\x80\xe5\xca\x81\xe2\xc5\x82\xe3\xcc"
)
s.add_packet(something)
data = sock.recv()
self.assertEqual(data, "こんにちは")
s.add_packet(b"\x81\x85abcd)\x07\x0f\x08\x0e")
data = sock.recv()
self.assertEqual(data, "Hello")
@unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
def test_iter(self):
count = 2
s = ws.create_connection("wss://api.bitfinex.com/ws/2")
s.send('{"event": "subscribe", "channel": "ticker"}')
for _ in s:
count -= 1
if count == 0:
break
@unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
def test_next(self):
sock = ws.create_connection("wss://api.bitfinex.com/ws/2")
self.assertEqual(str, type(next(sock)))
def test_internal_recv_strict(self):
sock = ws.WebSocket()
s = sock.sock = SockMock()
s.add_packet(b"foo")
s.add_packet(socket.timeout())
s.add_packet(b"bar")
# s.add_packet(SSLError("The read operation timed out"))
s.add_packet(b"baz")
with self.assertRaises(ws.WebSocketTimeoutException):
sock.frame_buffer.recv_strict(9)
# with self.assertRaises(SSLError):
# data = sock._recv_strict(9)
data = sock.frame_buffer.recv_strict(9)
self.assertEqual(data, b"foobarbaz")
with self.assertRaises(ws.WebSocketConnectionClosedException):
sock.frame_buffer.recv_strict(1)
def test_recv_timeout(self):
sock = ws.WebSocket()
s = sock.sock = SockMock()
s.add_packet(b"\x81")
s.add_packet(socket.timeout())
s.add_packet(b"\x8dabcd\x29\x07\x0f\x08\x0e")
s.add_packet(socket.timeout())
s.add_packet(b"\x4e\x43\x33\x0e\x10\x0f\x00\x40")
with self.assertRaises(ws.WebSocketTimeoutException):
sock.recv()
with self.assertRaises(ws.WebSocketTimeoutException):
sock.recv()
data = sock.recv()
self.assertEqual(data, "Hello, World!")
with self.assertRaises(ws.WebSocketConnectionClosedException):
sock.recv()
def test_recv_with_simple_fragmentation(self):
sock = ws.WebSocket()
s = sock.sock = SockMock()
# OPCODE=TEXT, FIN=0, MSG="Brevity is "
s.add_packet(b"\x01\x8babcd#\x10\x06\x12\x08\x16\x1aD\x08\x11C")
# OPCODE=CONT, FIN=1, MSG="the soul of wit"
s.add_packet(b"\x80\x8fabcd\x15\n\x06D\x12\r\x16\x08A\r\x05D\x16\x0b\x17")
data = sock.recv()
self.assertEqual(data, "Brevity is the soul of wit")
with self.assertRaises(ws.WebSocketConnectionClosedException):
sock.recv()
def test_recv_with_fire_event_of_fragmentation(self):
sock = ws.WebSocket(fire_cont_frame=True)
s = sock.sock = SockMock()
# OPCODE=TEXT, FIN=0, MSG="Brevity is "
s.add_packet(b"\x01\x8babcd#\x10\x06\x12\x08\x16\x1aD\x08\x11C")
# OPCODE=CONT, FIN=0, MSG="Brevity is "
s.add_packet(b"\x00\x8babcd#\x10\x06\x12\x08\x16\x1aD\x08\x11C")
# OPCODE=CONT, FIN=1, MSG="the soul of wit"
s.add_packet(b"\x80\x8fabcd\x15\n\x06D\x12\r\x16\x08A\r\x05D\x16\x0b\x17")
_, data = sock.recv_data()
self.assertEqual(data, b"Brevity is ")
_, data = sock.recv_data()
self.assertEqual(data, b"Brevity is ")
_, data = sock.recv_data()
self.assertEqual(data, b"the soul of wit")
# OPCODE=CONT, FIN=0, MSG="Brevity is "
s.add_packet(b"\x80\x8babcd#\x10\x06\x12\x08\x16\x1aD\x08\x11C")
with self.assertRaises(ws.WebSocketException):
sock.recv_data()
with self.assertRaises(ws.WebSocketConnectionClosedException):
sock.recv()
def test_close(self):
sock = ws.WebSocket()
sock.connected = True
sock.close
sock = ws.WebSocket()
s = sock.sock = SockMock()
sock.connected = True
s.add_packet(b"\x88\x80\x17\x98p\x84")
sock.recv()
self.assertEqual(sock.connected, False)
def test_recv_cont_fragmentation(self):
sock = ws.WebSocket()
s = sock.sock = SockMock()
# OPCODE=CONT, FIN=1, MSG="the soul of wit"
s.add_packet(b"\x80\x8fabcd\x15\n\x06D\x12\r\x16\x08A\r\x05D\x16\x0b\x17")
self.assertRaises(ws.WebSocketException, sock.recv)
def test_recv_with_prolonged_fragmentation(self):
sock = ws.WebSocket()
s = sock.sock = SockMock()
# OPCODE=TEXT, FIN=0, MSG="Once more unto the breach, "
s.add_packet(
b"\x01\x9babcd.\x0c\x00\x01A\x0f\x0c\x16\x04B\x16\n\x15\rC\x10\t\x07C\x06\x13\x07\x02\x07\tNC"
)
# OPCODE=CONT, FIN=0, MSG="dear friends, "
s.add_packet(b"\x00\x8eabcd\x05\x07\x02\x16A\x04\x11\r\x04\x0c\x07\x17MB")
# OPCODE=CONT, FIN=1, MSG="once more"
s.add_packet(b"\x80\x89abcd\x0e\x0c\x00\x01A\x0f\x0c\x16\x04")
data = sock.recv()
self.assertEqual(data, "Once more unto the breach, dear friends, once more")
with self.assertRaises(ws.WebSocketConnectionClosedException):
sock.recv()
def test_recv_with_fragmentation_and_control_frame(self):
sock = ws.WebSocket()
sock.set_mask_key(create_mask_key)
s = sock.sock = SockMock()
# OPCODE=TEXT, FIN=0, MSG="Too much "
s.add_packet(b"\x01\x89abcd5\r\x0cD\x0c\x17\x00\x0cA")
# OPCODE=PING, FIN=1, MSG="Please PONG this"
s.add_packet(b"\x89\x90abcd1\x0e\x06\x05\x12\x07C4.,$D\x15\n\n\x17")
# OPCODE=CONT, FIN=1, MSG="of a good thing"
s.add_packet(b"\x80\x8fabcd\x0e\x04C\x05A\x05\x0c\x0b\x05B\x17\x0c\x08\x0c\x04")
data = sock.recv()
self.assertEqual(data, "Too much of a good thing")
with self.assertRaises(ws.WebSocketConnectionClosedException):
sock.recv()
self.assertEqual(
s.sent[0], b"\x8a\x90abcd1\x0e\x06\x05\x12\x07C4.,$D\x15\n\n\x17"
)
@unittest.skipUnless(
TEST_WITH_LOCAL_SERVER, "Tests using local websocket server are disabled"
)
def test_websocket(self):
s = ws.create_connection(f"ws://127.0.0.1:{LOCAL_WS_SERVER_PORT}")
self.assertNotEqual(s, None)
s.send("Hello, World")
result = s.next()
s.fileno()
self.assertEqual(result, "Hello, World")
s.send("こにゃにゃちは、世界")
result = s.recv()
self.assertEqual(result, "こにゃにゃちは、世界")
self.assertRaises(ValueError, s.send_close, -1, "")
s.close()
@unittest.skipUnless(
TEST_WITH_LOCAL_SERVER, "Tests using local websocket server are disabled"
)
def test_ping_pong(self):
s = ws.create_connection(f"ws://127.0.0.1:{LOCAL_WS_SERVER_PORT}")
self.assertNotEqual(s, None)
s.ping("Hello")
s.pong("Hi")
s.close()
@unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
def test_support_redirect(self):
s = ws.WebSocket()
self.assertRaises(WebSocketBadStatusException, s.connect, "ws://google.com/")
# Need to find a URL that has a redirect code leading to a websocket
@unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
def test_secure_websocket(self):
s = ws.create_connection("wss://api.bitfinex.com/ws/2")
self.assertNotEqual(s, None)
self.assertTrue(isinstance(s.sock, ssl.SSLSocket))
self.assertEqual(s.getstatus(), 101)
self.assertNotEqual(s.getheaders(), None)
s.settimeout(10)
self.assertEqual(s.gettimeout(), 10)
self.assertEqual(s.getsubprotocol(), None)
s.abort()
@unittest.skipUnless(
TEST_WITH_LOCAL_SERVER, "Tests using local websocket server are disabled"
)
def test_websocket_with_custom_header(self):
s = ws.create_connection(
f"ws://127.0.0.1:{LOCAL_WS_SERVER_PORT}",
headers={"User-Agent": "PythonWebsocketClient"},
)
self.assertNotEqual(s, None)
self.assertEqual(s.getsubprotocol(), None)
s.send("Hello, World")
result = s.recv()
self.assertEqual(result, "Hello, World")
self.assertRaises(ValueError, s.close, -1, "")
s.close()
@unittest.skipUnless(
TEST_WITH_LOCAL_SERVER, "Tests using local websocket server are disabled"
)
def test_after_close(self):
s = ws.create_connection(f"ws://127.0.0.1:{LOCAL_WS_SERVER_PORT}")
self.assertNotEqual(s, None)
s.close()
self.assertRaises(ws.WebSocketConnectionClosedException, s.send, "Hello")
self.assertRaises(ws.WebSocketConnectionClosedException, s.recv)
class SockOptTest(unittest.TestCase):
@unittest.skipUnless(
TEST_WITH_LOCAL_SERVER, "Tests using local websocket server are disabled"
)
def test_sockopt(self):
sockopt = ((socket.IPPROTO_TCP, socket.TCP_NODELAY, 1),)
s = ws.create_connection(
f"ws://127.0.0.1:{LOCAL_WS_SERVER_PORT}", sockopt=sockopt
)
self.assertNotEqual(
s.sock.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY), 0
)
s.close()
class UtilsTest(unittest.TestCase):
def test_utf8_validator(self):
state = validate_utf8(b"\xf0\x90\x80\x80")
self.assertEqual(state, True)
state = validate_utf8(
b"\xce\xba\xe1\xbd\xb9\xcf\x83\xce\xbc\xce\xb5\xed\xa0\x80edited"
)
self.assertEqual(state, False)
state = validate_utf8(b"")
self.assertEqual(state, True)
class HandshakeTest(unittest.TestCase):
@unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
def test_http_ssl(self):
websock1 = ws.WebSocket(
sslopt={"cert_chain": ssl.get_default_verify_paths().capath},
enable_multithread=False,
)
self.assertRaises(ValueError, websock1.connect, "wss://api.bitfinex.com/ws/2")
websock2 = ws.WebSocket(sslopt={"certfile": "myNonexistentCertFile"})
self.assertRaises(
FileNotFoundError, websock2.connect, "wss://api.bitfinex.com/ws/2"
)
@unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
def test_manual_headers(self):
websock3 = ws.WebSocket(
sslopt={
"ca_certs": ssl.get_default_verify_paths().cafile,
"ca_cert_path": ssl.get_default_verify_paths().capath,
}
)
self.assertRaises(
WebSocketBadStatusException,
websock3.connect,
"wss://api.bitfinex.com/ws/2",
cookie="chocolate",
origin="testing_websockets.com",
host="echo.websocket.events/websocket-client-test",
subprotocols=["testproto"],
connection="Upgrade",
header={
"CustomHeader1": "123",
"Cookie": "TestValue",
"Sec-WebSocket-Key": "k9kFAUWNAMmf5OEMfTlOEA==",
"Sec-WebSocket-Protocol": "newprotocol",
},
)
def test_ipv6(self):
websock2 = ws.WebSocket()
self.assertRaises(ValueError, websock2.connect, "2001:4860:4860::8888")
def test_bad_urls(self):
websock3 = ws.WebSocket()
self.assertRaises(ValueError, websock3.connect, "ws//example.com")
self.assertRaises(WebSocketAddressException, websock3.connect, "ws://example")
self.assertRaises(ValueError, websock3.connect, "example.com")
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,199 @@
# Python imports
from os import path
import json
# Lib imports
import gi
gi.require_version('Gtk', '3.0')
gi.require_version('GtkSource', '4')
from gi.repository import Gtk
from gi.repository import GLib
from gi.repository import GtkSource
# Application imports
from .provider import Provider
class LSPManager(Gtk.Dialog):
def __init__(self):
super(LSPManager, self).__init__()
self._SCRIPT_PTH: str = path.dirname( path.realpath(__file__) )
self._USER_HOME: str = path.expanduser('~')
self._LSP_SERVERS_CONFIG: str = ""
self.servers_config: dict = {}
self.provider: Provider = Provider()
self.parent = None
self.source_view = None
self._setup_styling()
self._setup_signals()
self._subscribe_to_events()
self._load_widgets()
def _setup_styling(self):
self.set_modal(True)
self.set_decorated(False)
self.set_vexpand(True)
self.set_hexpand(True)
def _setup_signals(self):
self.connect("show", self._show)
def _subscribe_to_events(self):
...
def _load_widgets(self):
content_area = self.get_content_area()
self.main_box = Gtk.Grid()
self.path_entry = Gtk.SearchEntry()
self.path_bttn = Gtk.FileChooserButton.new(
title = "Workspace Folder",
action = Gtk.FileChooserAction.SELECT_FOLDER
)
self.combo_box = Gtk.ComboBoxText()
self.hide_bttn = Gtk.Button(label = "X")
bttn_box = Gtk.Box()
create_client_bttn = Gtk.Button(label = "Create Language Client")
close_client_bttn = Gtk.Button(label = "Close Language Client")
self.path_entry.set_can_focus(False)
self.path_entry.set_placeholder_text("Workspace Folder...")
self.path_entry.connect("changed", self._path_changed, bttn_box)
self.path_bttn.connect("file-set", self._file_set)
self.path_bttn.set_halign(Gtk.Align.FILL)
self.hide_bttn.connect("clicked", lambda widget: self.hide())
create_client_bttn.connect("clicked", self.create_client, close_client_bttn)
close_client_bttn.connect("clicked", self.close_client, create_client_bttn)
self.main_box.set_column_spacing(15)
self.main_box.set_row_spacing(15)
bttn_box.pack_start(create_client_bttn, False, False, 0)
bttn_box.pack_start(close_client_bttn, False, False, 0)
self.main_box.attach(child = self.path_entry, left = 0, top = 0, width = 4, height = 1)
self.main_box.attach(child = self.path_bttn, left = 4, top = 0, width = 1, height = 1)
self.main_box.attach(child = self.combo_box, left = 5, top = 0, width = 1, height = 1)
self.main_box.attach(child = self.hide_bttn, left = 6, top = 0, width = 1, height = 1)
self.main_box.attach(child = bttn_box, left = 0, top = 1, width = 1, height = 1)
content_area.set_vexpand(True)
content_area.set_hexpand(True)
content_area.add(self.main_box)
content_area.show_all()
close_client_bttn.hide()
bttn_box.hide()
def _show(self, widget):
GLib.idle_add(self.path_entry.grab_focus)
def _map_resize(self, widget, parent):
parent_x, parent_y = parent.get_position()
parent_width, parent_height = parent.get_size()
if parent_width == 0 or parent_height == 0: return
width = int(parent_width * 0.75)
height = int(parent_height * 0.75)
widget.resize(width, height)
x = parent_x + (parent_width - width) // 2
y = parent_y + (parent_height - height) // 2
widget.move(x, y)
def _path_changed(self, widget, buttons_widget):
fpath = widget.get_text()
if not fpath:
buttons_widget.hide()
return
buttons_widget.show()
def _file_set(self, widget):
self.path_entry.set_text(
widget.get_filename()
)
self.load_lsp_servers_config_placeholders()
def map_parent_resize_event(self, parent):
parent.connect("size-allocate", lambda w, r: self._map_resize(self, parent))
def set_source_view(self, source_view):
scrolled_win = Gtk.ScrolledWindow()
lang_manager = GtkSource.LanguageManager()
buffer = source_view.get_buffer()
language = lang_manager.get_language("json")
self.source_view = source_view
buffer.set_language(language)
buffer.set_style_scheme(self.source_view.syntax_theme)
scrolled_win.set_hexpand(True)
scrolled_win.set_vexpand(True)
scrolled_win.add(self.source_view)
self.main_box.attach(child = scrolled_win, left = 0, top = 2, width = 7, height = 1)
scrolled_win.show_all()
def load_lsp_servers_config(self):
with open(f"{self._SCRIPT_PTH}/configs/lsp-servers-config.json") as file:
self._LSP_SERVERS_CONFIG = file.read()
def load_lsp_servers_config_placeholders(self):
data = self._LSP_SERVERS_CONFIG \
.replace("{user.home}", self._USER_HOME) \
.replace("{workspace.folder}", self.path_entry.get_text())
self.servers_config = json.loads(data)
buffer = self.source_view.get_buffer()
start_itr, \
end_itr = buffer.get_bounds()
buffer.delete(start_itr, end_itr)
buffer.insert(start_itr, data, -1)
self.set_language_combo_box( self.servers_config.keys() )
def set_language_combo_box(self, lang_ids: list[str]):
for lang_id in lang_ids:
self.combo_box.append_text(lang_id)
def create_client(self, widget, sibling):
buffer = self.source_view.get_buffer()
lang_id = self.combo_box.get_active_text()
if not lang_id: return
if not lang_id in self.servers_config: return
self.servers_config = json.loads( buffer.get_text( *buffer.get_bounds() ) )
init_opts = self.servers_config[lang_id]["initialization-options"]
workspace_dir = self.path_entry.get_text()
result = self.provider.response_cache.create_client(
lang_id, workspace_dir, init_opts
)
if not result: return
widget.hide()
sibling.show()
def close_client(self, widget, sibling):
lang_id = self.combo_box.get_active_text()
if not lang_id: return
result = self.provider.response_cache.close_client(lang_id)
if not result: return
widget.hide()
sibling.show()

View File

@@ -0,0 +1,7 @@
{
"name": "LSP Manager",
"author": "ITDominator",
"version": "0.0.1",
"support": "",
"requests": {}
}

View File

@@ -0,0 +1,3 @@
"""
Pligin Module Mixins
"""

View File

@@ -0,0 +1,144 @@
# Python imports
# Lib imports
import gi
from gi.repository import GLib
# Application imports
from libs.event_factory import Code_Event_Types
class LSPClientEventsMixin:
def process_file_load(self, event: Code_Event_Types.AddedNewFileEvent):
lang_id = event.file.ftype
if lang_id not in self.clients:
logger.debug(f"No LSP client for '{lang_id}', skipping didOpen")
return
controller = self.clients[lang_id]
fpath = event.file.fpath
uri = f"file://{fpath}" if not fpath.startswith("file://") else fpath
buffer = event.file.buffer
text = buffer.get_text(*buffer.get_bounds())
self._last_active_language_id = lang_id
controller._lsp_did_open({
"uri": uri,
"language_id": lang_id,
"text": text
})
def process_file_close(self, event: Code_Event_Types.RemovedFileEvent):
lang_id = event.file.ftype
if lang_id not in self.clients:
logger.debug(f"No LSP client for '{lang_id}', skipping didClose")
return
controller = self.clients[lang_id]
fpath = event.file.fpath
uri = f"file://{fpath}" if not fpath.startswith("file://") else fpath
controller._lsp_did_close({"uri": uri})
def process_file_save(self, event: Code_Event_Types.SavedFileEvent):
lang_id = event.file.ftype
if lang_id not in self.clients:
logger.debug(f"No LSP client for '{lang_id}', skipping didSave")
return
controller = self.clients[lang_id]
fpath = event.file.fpath
uri = f"file://{fpath}" if not fpath.startswith("file://") else fpath
buffer = event.file.buffer
text = buffer.get_text(*buffer.get_bounds())
self._last_active_language_id = lang_id
controller._lsp_did_save({"uri": uri, "text": text})
def process_file_change(self, event: Code_Event_Types.TextChangedEvent):
self._clear_delayed_cache_refresh_trigger()
lang_id = event.file.ftype
if lang_id not in self.clients:
logger.debug(f"No LSP client for '{lang_id}', skipping didChange")
return
controller = self.clients[lang_id]
fpath = event.file.fpath
uri = f"file://{fpath}" if not fpath.startswith("file://") else fpath
buffer = event.file.buffer
text = buffer.get_text(*buffer.get_bounds())
self._last_active_language_id = lang_id
controller._lsp_did_change({
"uri": uri,
"language_id": lang_id,
"version": 1,
"text": text
})
iter = buffer.get_iter_at_mark( buffer.get_insert() )
line = iter.get_line()
column = iter.get_line_offset()
self._set_cache_refresh_trigger(
lang_id, fpath, line, column
)
def process_goto_definition(
self, lang_id: str, fpath: str, line: int, column: int
):
if lang_id not in self.clients:
logger.debug(f"No LSP client for '{lang_id}', skipping goto definition")
return
controller = self.clients[lang_id]
uri = f"file://{fpath}" if not fpath.startswith("file://") else fpath
self._last_active_language_id = lang_id
controller._lsp_definition({
"uri": uri,
"language_id": lang_id,
"version": 1,
"line": line,
"column": column
})
def process_completion_request(
self, lang_id: str, fpath: str, line: int, column: int
):
if lang_id not in self.clients:
logger.debug(f"No LSP client for '{lang_id}', skipping completion")
return
controller = self.clients[lang_id]
uri = f"file://{fpath}" if not fpath.startswith("file://") else fpath
self._last_active_language_id = lang_id
controller._lsp_completion({
"uri": uri,
"language_id": lang_id,
"version": 1,
"line": line,
"column": column
})
def _clear_delayed_cache_refresh_trigger(self):
if self._cache_refresh_timeout_id:
GLib.source_remove(self._cache_refresh_timeout_id)
def _set_cache_refresh_trigger(
self, lang_id: str, fpath: str, line: int, column: int
):
def trigger_cache_refresh(lang_id, fpath, line, column):
self._cache_refresh_timeout_id = None
self.process_completion_request(
lang_id, fpath, line, column
)
return False
self._cache_refresh_timeout_id = GLib.timeout_add(1500, trigger_cache_refresh, lang_id, fpath, line, column)

View File

@@ -0,0 +1,59 @@
# Python imports
# Lib imports
import gi
from gi.repository import GLib
# Application imports
from libs.event_factory import Code_Event_Types
class LSPServerEventsMixin:
def _handle_definition_response(self, result: dict or list):
if not result: return
self._prompt_goto_request(result[0]["uri"])
def _handle_completion_response(self, result: dict or list):
if not result: return
items = []
if isinstance(result, dict):
items = result.get("items", [])
elif isinstance(result, list):
items = result
self.matchers.clear()
for item in items:
label = item.get("label", "")
if not label: continue
text = item.get("insertText")
if not text and "textEdit" in item:
text = item["textEdit"].get("newText", "")
info = ""
if "detail" in item:
info = item["detail"]
elif "documentation" in item:
doc = item["documentation"]
if isinstance(doc, dict):
info = doc.get("value", "")
else:
info = str(doc)
self.matchers[label] = {
"label": label,
"text": text,
"info": info
}
self._prompt_completion_request()
def _prompt_completion_request(self):
raise NotImplementedError
def _prompt_goto_request(self, uri: str):
raise NotImplementedError

View File

@@ -0,0 +1,121 @@
# Python imports
# Lib imports
# Application imports
from libs.event_factory import Event_Factory, Code_Event_Types
from libs.dto.states import SourceViewStates
from plugins.plugin_types import PluginCode
from .lsp_manager import LSPManager
lsp_manager = LSPManager()
class Plugin(PluginCode):
def __init__(self):
super(Plugin, self).__init__()
def _controller_message(self, event: Code_Event_Types.CodeEvent):
...
def load(self):
window = self.request_ui_element("main-window")
lsp_manager.map_parent_resize_event(window)
event = Event_Factory.create_event("register_command",
command_name = "LSP Manager",
command = Handler,
binding_mode = "released",
binding = ["<Shift><Control>l", "<Control>g", "<Control>i"]
)
self.emit_to("source_views", event)
event = Event_Factory.create_event(
"register_provider",
provider_name = "LSP Completer",
provider = lsp_manager.provider,
language_ids = []
)
self.emit_to("completion", event)
event = Event_Factory.create_event(
"create_source_view",
state = SourceViewStates.INDEPENDENT
)
self.emit_to("source_views", event)
source_view = event.response
lsp_manager.load_lsp_servers_config()
lsp_manager.set_source_view(source_view)
lsp_manager.load_lsp_servers_config_placeholders()
lsp_manager.provider.response_cache._prompt_completion_request = \
self._prompt_completion_request
lsp_manager.provider.response_cache._prompt_goto_request = \
self._prompt_goto_request
def run(self):
...
def generate_plugin_element(self):
...
def _prompt_completion_request(self):
event = Event_Factory.create_event(
"get_active_view",
)
self.emit_to("source_views", event)
view = event.response
event = Event_Factory.create_event(
"request_completion",
view = view,
provider = lsp_manager.provider
)
self.emit_to("completion", event)
def _prompt_goto_request(self, uri: str):
event = Event_Factory.create_event(
"get_active_view",
)
self.emit_to("source_views", event)
view = event.response
view._on_uri_data_received( [uri] )
class Handler:
@staticmethod
def execute(
view: any,
*args,
**kwargs
):
logger.debug("Command: LSP Manager")
char_str = args[0]
if char_str in ["g", "i"]:
file = view.command.exec("get_current_file")
buffer = view.get_buffer()
iter = buffer.get_iter_at_mark( buffer.get_insert() )
line = iter.get_line()
column = iter.get_line_offset()
if char_str == "g":
lsp_manager.provider.response_cache.process_goto_definition(
file.ftype, file.fpath, line, column
)
return
if char_str == "i":
return
lsp_manager.hide() if lsp_manager.is_visible() else lsp_manager.show()

View File

@@ -0,0 +1,80 @@
# Python imports
# Lib imports
import gi
gi.require_version('GtkSource', '4')
from gi.repository import GObject
from gi.repository import GtkSource
# Application imports
from .provider_response_cache import ProviderResponseCache
class Provider(GObject.GObject, GtkSource.CompletionProvider):
"""
This code is an LSP code completion plugin for Newton.
# NOTE: Some code pulled/referenced from here --> https://github.com/isamert/gedi
"""
__gtype_name__ = 'LSPProvider'
def __init__(self):
super(Provider, self).__init__()
self.response_cache: ProviderResponseCache = ProviderResponseCache()
def pre_populate(self, context):
...
def do_get_name(self):
return "LSP Code Completion"
def do_match(self, context):
iter = self.response_cache.get_iter_correctly(context)
iter.backward_char()
ch = iter.get_char()
# NOTE: Look to re-add or apply supprting logic to use spaces
# As is it slows down the editor in certain contexts...
if not (ch in ('_', '.', ' ') or ch.isalnum()):
return False
buffer = iter.get_buffer()
if buffer.get_context_classes_at_iter(iter) != ['no-spell-check']:
return False
return True
def do_get_priority(self):
return 5
def do_activate_proposal(self, proposal, iter_):
buffer = iter_.get_buffer()
# Note: Flag mostly intended for SourceViewsMultiInsertState
# to insure marker processes inserted text correctly.
buffer.is_processing_completion = True
return False
def do_get_activation(self):
""" The context for when a provider will show results """
# return GtkSource.CompletionActivation.NONE
# return GtkSource.CompletionActivation.USER_REQUESTED
# return GtkSource.CompletionActivation.INTERACTIVE
return GtkSource.CompletionActivation.INTERACTIVE | GtkSource.CompletionActivation.USER_REQUESTED
def do_populate(self, context):
results = self.response_cache.filter_with_context(context)
proposals = []
for entry in results:
proposals.append(
self.response_cache.create_completion_item(
entry["label"],
entry["text"],
entry["info"]
)
)
context.add_proposals(self, proposals, True)

View File

@@ -0,0 +1,114 @@
# Python imports
from concurrent.futures import ThreadPoolExecutor
import asyncio
from asyncio import Queue
# Lib imports
import gi
gi.require_version('GtkSource', '4')
from gi.repository import GtkSource
# Application imports
from libs.dto.code.lsp.lsp_message_structs import LSPResponseTypes, LSPResponseRequest, LSPResponseNotification
from core.widgets.code.completion_providers.provider_response_cache_base import ProviderResponseCacheBase
from .controllers.lsp_controller import LSPController
from .mixins.lsp_client_events_mixin import LSPClientEventsMixin
from .mixins.lsp_server_events_mixin import LSPServerEventsMixin
class ProviderResponseCache(LSPClientEventsMixin, LSPServerEventsMixin, ProviderResponseCacheBase):
def __init__(self):
super(ProviderResponseCache, self).__init__()
self.executor = ThreadPoolExecutor(max_workers = 1)
self.matchers: dict = {}
self.clients: dict = {}
self._cache_refresh_timeout_id: int = None
self._last_active_language_id: str = None
def create_client(
self,
lang_id: str = "python",
workspace_uri: str = "",
init_opts: dict = {
}) -> bool:
if lang_id in self.clients: return False
address = "127.0.0.1"
port = 9999
uri = f"ws://{address}:{port}/{lang_id}"
controller = LSPController()
controller.handle_lsp_response = self.server_response
controller.set_language(lang_id)
controller.set_socket(uri)
controller.start_client()
if not controller.ws_client.wait_for_connection(timeout = 5.0):
logger.error(f"Failed to connect to LSP server for {lang_id}")
return False
self.clients[lang_id] = controller
controller.send_initialize_message(init_opts, "", f"file://{workspace_uri}")
return True
def close_client(self, lang_id: str) -> bool:
if lang_id not in self.clients: return False
controller = self.clients.pop(lang_id)
controller.stop_client()
return True
def server_response(self, lsp_response: LSPResponseTypes):
logger.debug(f"LSP Response: { lsp_response }")
if isinstance(lsp_response, LSPResponseRequest):
if not self._last_active_language_id in self.clients:
logger.debug(f"No LSP client for '{self._last_active_language_id}', skipping 'server_response'")
return
controller = self.clients[self._last_active_language_id]
event = controller.get_event_by_id(lsp_response.id)
match event:
case "textDocument/completion":
self._handle_completion_response(lsp_response.result)
case "textDocument/definition":
self._handle_definition_response(lsp_response.result)
case _:
...
elif isinstance(lsp_response, LSPResponseNotification):
match lsp_response.method:
case "textDocument/publishDiagnostics":
...
case _:
...
def filter(self, word: str) -> list[dict]:
return []
def filter_with_context(self, context: GtkSource.CompletionContext) -> list[dict]:
response = []
iter = self.get_iter_correctly(context)
iter.backward_char()
char_str = iter.get_char()
if char_str == "." or char_str == " ":
for label, item in self.matchers.items():
response.append(item)
return response
word = self.get_word(context).rstrip()
for label, item in self.matchers.items():
if label.startswith(word):
response.append(item)
return response

View File

@@ -0,0 +1,3 @@
"""
Pligin Module
"""

View File

@@ -0,0 +1,3 @@
"""
Pligin Package
"""

View File

@@ -0,0 +1,7 @@
{
"name": "Markdown Preview",
"author": "ITDominator",
"version": "0.0.1",
"support": "",
"requests": {}
}

View File

@@ -0,0 +1,48 @@
# Python Markdown
# A Python implementation of John Gruber's Markdown.
# - Documentation: https://python-markdown.github.io/
# - GitHub: https://github.com/Python-Markdown/markdown/
# - PyPI: https://pypi.org/project/Markdown/
# Started by Manfred Stienstra (http://www.dwerg.net/).
# Maintained for a few years by Yuri Takhteyev (http://www.freewisdom.org).
# Currently maintained by Waylan Limberg (https://github.com/waylan),
# Dmitry Shachnev (https://github.com/mitya57) and Isaac Muse (https://github.com/facelessuser).
# - Copyright 2007-2023 The Python Markdown Project (v. 1.7 and later)
# - Copyright 2004, 2005, 2006 Yuri Takhteyev (v. 0.2-1.6b)
# - Copyright 2004 Manfred Stienstra (the original version)
# License: BSD (see LICENSE.md for details).
"""
Python-Markdown provides two public functions ([`markdown.markdown`][] and [`markdown.markdownFromFile`][])
both of which wrap the public class [`markdown.Markdown`][]. All submodules support these public functions
and class and/or provide extension support.
Modules:
core: Core functionality.
preprocessors: Pre-processors.
blockparser: Core Markdown block parser.
blockprocessors: Block processors.
treeprocessors: Tree processors.
inlinepatterns: Inline patterns.
postprocessors: Post-processors.
serializers: Serializers.
util: Utility functions.
htmlparser: HTML parser.
test_tools: Testing utilities.
extensions: Markdown extensions.
"""
from __future__ import annotations
from .core import Markdown, markdown, markdownFromFile
from .__meta__ import __version__, __version_info__ # noqa
# For backward compatibility as some extensions expect it...
from .extensions import Extension # noqa
__all__ = ['Markdown', 'markdown', 'markdownFromFile']

View File

@@ -0,0 +1,151 @@
# Python Markdown
# A Python implementation of John Gruber's Markdown.
# Documentation: https://python-markdown.github.io/
# GitHub: https://github.com/Python-Markdown/markdown/
# PyPI: https://pypi.org/project/Markdown/
# Started by Manfred Stienstra (http://www.dwerg.net/).
# Maintained for a few years by Yuri Takhteyev (http://www.freewisdom.org).
# Currently maintained by Waylan Limberg (https://github.com/waylan),
# Dmitry Shachnev (https://github.com/mitya57) and Isaac Muse (https://github.com/facelessuser).
# Copyright 2007-2023 The Python Markdown Project (v. 1.7 and later)
# Copyright 2004, 2005, 2006 Yuri Takhteyev (v. 0.2-1.6b)
# Copyright 2004 Manfred Stienstra (the original version)
# License: BSD (see LICENSE.md for details).
from __future__ import annotations
import sys
import optparse
import codecs
import warnings
import markdown
try:
# We use `unsafe_load` because users may need to pass in actual Python
# objects. As this is only available from the CLI, the user has much
# worse problems if an attacker can use this as an attach vector.
from yaml import unsafe_load as yaml_load
except ImportError: # pragma: no cover
try:
# Fall back to PyYAML <5.1
from yaml import load as yaml_load
except ImportError:
# Fall back to JSON
from json import load as yaml_load
import logging
from logging import DEBUG, WARNING, CRITICAL
logger = logging.getLogger('MARKDOWN')
def parse_options(args=None, values=None):
"""
Define and parse `optparse` options for command-line usage.
"""
usage = """%prog [options] [INPUTFILE]
(STDIN is assumed if no INPUTFILE is given)"""
desc = "A Python implementation of John Gruber's Markdown. " \
"https://Python-Markdown.github.io/"
ver = "%%prog %s" % markdown.__version__
parser = optparse.OptionParser(usage=usage, description=desc, version=ver)
parser.add_option("-f", "--file", dest="filename", default=None,
help="Write output to OUTPUT_FILE. Defaults to STDOUT.",
metavar="OUTPUT_FILE")
parser.add_option("-e", "--encoding", dest="encoding",
help="Encoding for input and output files.",)
parser.add_option("-o", "--output_format", dest="output_format",
default='xhtml', metavar="OUTPUT_FORMAT",
help="Use output format 'xhtml' (default) or 'html'.")
parser.add_option("-n", "--no_lazy_ol", dest="lazy_ol",
action='store_false', default=True,
help="Observe number of first item of ordered lists.")
parser.add_option("-x", "--extension", action="append", dest="extensions",
help="Load extension EXTENSION.", metavar="EXTENSION")
parser.add_option("-c", "--extension_configs",
dest="configfile", default=None,
help="Read extension configurations from CONFIG_FILE. "
"CONFIG_FILE must be of JSON or YAML format. YAML "
"format requires that a python YAML library be "
"installed. The parsed JSON or YAML must result in a "
"python dictionary which would be accepted by the "
"'extension_configs' keyword on the markdown.Markdown "
"class. The extensions must also be loaded with the "
"`--extension` option.",
metavar="CONFIG_FILE")
parser.add_option("-q", "--quiet", default=CRITICAL,
action="store_const", const=CRITICAL+10, dest="verbose",
help="Suppress all warnings.")
parser.add_option("-v", "--verbose",
action="store_const", const=WARNING, dest="verbose",
help="Print all warnings.")
parser.add_option("--noisy",
action="store_const", const=DEBUG, dest="verbose",
help="Print debug messages.")
(options, args) = parser.parse_args(args, values)
if len(args) == 0:
input_file = None
else:
input_file = args[0]
if not options.extensions:
options.extensions = []
extension_configs = {}
if options.configfile:
with codecs.open(
options.configfile, mode="r", encoding=options.encoding
) as fp:
try:
extension_configs = yaml_load(fp)
except Exception as e:
message = "Failed parsing extension config file: %s" % \
options.configfile
e.args = (message,) + e.args[1:]
raise
opts = {
'input': input_file,
'output': options.filename,
'extensions': options.extensions,
'extension_configs': extension_configs,
'encoding': options.encoding,
'output_format': options.output_format,
'lazy_ol': options.lazy_ol
}
return opts, options.verbose
def run(): # pragma: no cover
"""Run Markdown from the command line."""
# Parse options and adjust logging level if necessary
options, logging_level = parse_options()
if not options:
sys.exit(2)
logger.setLevel(logging_level)
console_handler = logging.StreamHandler()
logger.addHandler(console_handler)
if logging_level <= WARNING:
# Ensure deprecation warnings get displayed
warnings.filterwarnings('default')
logging.captureWarnings(True)
warn_logger = logging.getLogger('py.warnings')
warn_logger.addHandler(console_handler)
# Run
markdown.markdownFromFile(**options)
if __name__ == '__main__': # pragma: no cover
# Support running module as a command line command.
# python -m markdown [options] [args]
run()

View File

@@ -0,0 +1,51 @@
# Python Markdown
# A Python implementation of John Gruber's Markdown.
# Documentation: https://python-markdown.github.io/
# GitHub: https://github.com/Python-Markdown/markdown/
# PyPI: https://pypi.org/project/Markdown/
# Started by Manfred Stienstra (http://www.dwerg.net/).
# Maintained for a few years by Yuri Takhteyev (http://www.freewisdom.org).
# Currently maintained by Waylan Limberg (https://github.com/waylan),
# Dmitry Shachnev (https://github.com/mitya57) and Isaac Muse (https://github.com/facelessuser).
# Copyright 2007-2023 The Python Markdown Project (v. 1.7 and later)
# Copyright 2004, 2005, 2006 Yuri Takhteyev (v. 0.2-1.6b)
# Copyright 2004 Manfred Stienstra (the original version)
# License: BSD (see LICENSE.md for details).
# __version_info__ format:
# (major, minor, patch, dev/alpha/beta/rc/final, #)
# (1, 1, 2, 'dev', 0) => "1.1.2.dev0"
# (1, 1, 2, 'alpha', 1) => "1.1.2a1"
# (1, 2, 0, 'beta', 2) => "1.2b2"
# (1, 2, 0, 'rc', 4) => "1.2rc4"
# (1, 2, 0, 'final', 0) => "1.2"
from __future__ import annotations
__version_info__ = (3, 5, 1, 'final', 0)
def _get_version(version_info):
" Returns a PEP 440-compliant version number from `version_info`. "
assert len(version_info) == 5
assert version_info[3] in ('dev', 'alpha', 'beta', 'rc', 'final')
parts = 2 if version_info[2] == 0 else 3
v = '.'.join(map(str, version_info[:parts]))
if version_info[3] == 'dev':
v += '.dev' + str(version_info[4])
elif version_info[3] != 'final':
mapping = {'alpha': 'a', 'beta': 'b', 'rc': 'rc'}
v += mapping[version_info[3]] + str(version_info[4])
return v
__version__ = _get_version(__version_info__)

View File

@@ -0,0 +1,160 @@
# Python Markdown
# A Python implementation of John Gruber's Markdown.
# Documentation: https://python-markdown.github.io/
# GitHub: https://github.com/Python-Markdown/markdown/
# PyPI: https://pypi.org/project/Markdown/
# Started by Manfred Stienstra (http://www.dwerg.net/).
# Maintained for a few years by Yuri Takhteyev (http://www.freewisdom.org).
# Currently maintained by Waylan Limberg (https://github.com/waylan),
# Dmitry Shachnev (https://github.com/mitya57) and Isaac Muse (https://github.com/facelessuser).
# Copyright 2007-2023 The Python Markdown Project (v. 1.7 and later)
# Copyright 2004, 2005, 2006 Yuri Takhteyev (v. 0.2-1.6b)
# Copyright 2004 Manfred Stienstra (the original version)
# License: BSD (see LICENSE.md for details).
"""
The block parser handles basic parsing of Markdown blocks. It doesn't concern
itself with inline elements such as `**bold**` or `*italics*`, but rather just
catches blocks, lists, quotes, etc.
The `BlockParser` is made up of a bunch of `BlockProcessors`, each handling a
different type of block. Extensions may add/replace/remove `BlockProcessors`
as they need to alter how Markdown blocks are parsed.
"""
from __future__ import annotations
import xml.etree.ElementTree as etree
from typing import TYPE_CHECKING, Iterable, Any
from . import util
if TYPE_CHECKING: # pragma: no cover
from markdown import Markdown
from .blockprocessors import BlockProcessor
class State(list):
""" Track the current and nested state of the parser.
This utility class is used to track the state of the `BlockParser` and
support multiple levels if nesting. It's just a simple API wrapped around
a list. Each time a state is set, that state is appended to the end of the
list. Each time a state is reset, that state is removed from the end of
the list.
Therefore, each time a state is set for a nested block, that state must be
reset when we back out of that level of nesting or the state could be
corrupted.
While all the methods of a list object are available, only the three
defined below need be used.
"""
def set(self, state: Any):
""" Set a new state. """
self.append(state)
def reset(self) -> None:
""" Step back one step in nested state. """
self.pop()
def isstate(self, state: Any) -> bool:
""" Test that top (current) level is of given state. """
if len(self):
return self[-1] == state
else:
return False
class BlockParser:
""" Parse Markdown blocks into an `ElementTree` object.
A wrapper class that stitches the various `BlockProcessors` together,
looping through them and creating an `ElementTree` object.
"""
def __init__(self, md: Markdown):
""" Initialize the block parser.
Arguments:
md: A Markdown instance.
Attributes:
BlockParser.md (Markdown): A Markdown instance.
BlockParser.state (State): Tracks the nesting level of current location in document being parsed.
BlockParser.blockprocessors (util.Registry): A collection of
[`blockprocessors`][markdown.blockprocessors].
"""
self.blockprocessors: util.Registry[BlockProcessor] = util.Registry()
self.state = State()
self.md = md
def parseDocument(self, lines: Iterable[str]) -> etree.ElementTree:
""" Parse a Markdown document into an `ElementTree`.
Given a list of lines, an `ElementTree` object (not just a parent
`Element`) is created and the root element is passed to the parser
as the parent. The `ElementTree` object is returned.
This should only be called on an entire document, not pieces.
Arguments:
lines: A list of lines (strings).
Returns:
An element tree.
"""
# Create an `ElementTree` from the lines
self.root = etree.Element(self.md.doc_tag)
self.parseChunk(self.root, '\n'.join(lines))
return etree.ElementTree(self.root)
def parseChunk(self, parent: etree.Element, text: str) -> None:
""" Parse a chunk of Markdown text and attach to given `etree` node.
While the `text` argument is generally assumed to contain multiple
blocks which will be split on blank lines, it could contain only one
block. Generally, this method would be called by extensions when
block parsing is required.
The `parent` `etree` Element passed in is altered in place.
Nothing is returned.
Arguments:
parent: The parent element.
text: The text to parse.
"""
self.parseBlocks(parent, text.split('\n\n'))
def parseBlocks(self, parent: etree.Element, blocks: list[str]) -> None:
""" Process blocks of Markdown text and attach to given `etree` node.
Given a list of `blocks`, each `blockprocessor` is stepped through
until there are no blocks left. While an extension could potentially
call this method directly, it's generally expected to be used
internally.
This is a public method as an extension may need to add/alter
additional `BlockProcessors` which call this method to recursively
parse a nested block.
Arguments:
parent: The parent element.
blocks: The blocks of text to parse.
"""
while blocks:
for processor in self.blockprocessors:
if processor.test(parent, blocks[0]):
if processor.run(parent, blocks) is not False:
# run returns True or None
break

View File

@@ -0,0 +1,636 @@
# Python Markdown
# A Python implementation of John Gruber's Markdown.
# Documentation: https://python-markdown.github.io/
# GitHub: https://github.com/Python-Markdown/markdown/
# PyPI: https://pypi.org/project/Markdown/
# Started by Manfred Stienstra (http://www.dwerg.net/).
# Maintained for a few years by Yuri Takhteyev (http://www.freewisdom.org).
# Currently maintained by Waylan Limberg (https://github.com/waylan),
# Dmitry Shachnev (https://github.com/mitya57) and Isaac Muse (https://github.com/facelessuser).
# Copyright 2007-2023 The Python Markdown Project (v. 1.7 and later)
# Copyright 2004, 2005, 2006 Yuri Takhteyev (v. 0.2-1.6b)
# Copyright 2004 Manfred Stienstra (the original version)
# License: BSD (see LICENSE.md for details).
"""
A block processor parses blocks of text and adds new elements to the ElementTree. Blocks of text,
separated from other text by blank lines, may have a different syntax and produce a differently
structured tree than other Markdown. Block processors excel at handling code formatting, equation
layouts, tables, etc.
"""
from __future__ import annotations
import logging
import re
import xml.etree.ElementTree as etree
from typing import TYPE_CHECKING, Any
from . import util
from .blockparser import BlockParser
if TYPE_CHECKING: # pragma: no cover
from markdown import Markdown
logger = logging.getLogger('MARKDOWN')
def build_block_parser(md: Markdown, **kwargs: Any) -> BlockParser:
""" Build the default block parser used by Markdown. """
parser = BlockParser(md)
parser.blockprocessors.register(EmptyBlockProcessor(parser), 'empty', 100)
parser.blockprocessors.register(ListIndentProcessor(parser), 'indent', 90)
parser.blockprocessors.register(CodeBlockProcessor(parser), 'code', 80)
parser.blockprocessors.register(HashHeaderProcessor(parser), 'hashheader', 70)
parser.blockprocessors.register(SetextHeaderProcessor(parser), 'setextheader', 60)
parser.blockprocessors.register(HRProcessor(parser), 'hr', 50)
parser.blockprocessors.register(OListProcessor(parser), 'olist', 40)
parser.blockprocessors.register(UListProcessor(parser), 'ulist', 30)
parser.blockprocessors.register(BlockQuoteProcessor(parser), 'quote', 20)
parser.blockprocessors.register(ReferenceProcessor(parser), 'reference', 15)
parser.blockprocessors.register(ParagraphProcessor(parser), 'paragraph', 10)
return parser
class BlockProcessor:
""" Base class for block processors.
Each subclass will provide the methods below to work with the source and
tree. Each processor will need to define it's own `test` and `run`
methods. The `test` method should return True or False, to indicate
whether the current block should be processed by this processor. If the
test passes, the parser will call the processors `run` method.
Attributes:
BlockProcessor.parser (BlockParser): The `BlockParser` instance this is attached to.
BlockProcessor.tab_length (int): The tab length set on the `Markdown` instance.
"""
def __init__(self, parser: BlockParser):
self.parser = parser
self.tab_length = parser.md.tab_length
def lastChild(self, parent: etree.Element) -> etree.Element | None:
""" Return the last child of an `etree` element. """
if len(parent):
return parent[-1]
else:
return None
def detab(self, text: str, length: int | None = None) -> tuple[str, str]:
""" Remove a tab from the front of each line of the given text. """
if length is None:
length = self.tab_length
newtext = []
lines = text.split('\n')
for line in lines:
if line.startswith(' ' * length):
newtext.append(line[length:])
elif not line.strip():
newtext.append('')
else:
break
return '\n'.join(newtext), '\n'.join(lines[len(newtext):])
def looseDetab(self, text: str, level: int = 1) -> str:
""" Remove a tab from front of lines but allowing dedented lines. """
lines = text.split('\n')
for i in range(len(lines)):
if lines[i].startswith(' '*self.tab_length*level):
lines[i] = lines[i][self.tab_length*level:]
return '\n'.join(lines)
def test(self, parent: etree.Element, block: str) -> bool:
""" Test for block type. Must be overridden by subclasses.
As the parser loops through processors, it will call the `test`
method on each to determine if the given block of text is of that
type. This method must return a boolean `True` or `False`. The
actual method of testing is left to the needs of that particular
block type. It could be as simple as `block.startswith(some_string)`
or a complex regular expression. As the block type may be different
depending on the parent of the block (i.e. inside a list), the parent
`etree` element is also provided and may be used as part of the test.
Keyword arguments:
parent: An `etree` element which will be the parent of the block.
block: A block of text from the source which has been split at blank lines.
"""
pass # pragma: no cover
def run(self, parent: etree.Element, blocks: list[str]) -> bool | None:
""" Run processor. Must be overridden by subclasses.
When the parser determines the appropriate type of a block, the parser
will call the corresponding processor's `run` method. This method
should parse the individual lines of the block and append them to
the `etree`.
Note that both the `parent` and `etree` keywords are pointers
to instances of the objects which should be edited in place. Each
processor must make changes to the existing objects as there is no
mechanism to return new/different objects to replace them.
This means that this method should be adding `SubElements` or adding text
to the parent, and should remove (`pop`) or add (`insert`) items to
the list of blocks.
If `False` is returned, this will have the same effect as returning `False`
from the `test` method.
Keyword arguments:
parent: An `etree` element which is the parent of the current block.
blocks: A list of all remaining blocks of the document.
"""
pass # pragma: no cover
class ListIndentProcessor(BlockProcessor):
""" Process children of list items.
Example
* a list item
process this part
or this part
"""
ITEM_TYPES = ['li']
""" List of tags used for list items. """
LIST_TYPES = ['ul', 'ol']
""" Types of lists this processor can operate on. """
def __init__(self, *args):
super().__init__(*args)
self.INDENT_RE = re.compile(r'^(([ ]{%s})+)' % self.tab_length)
def test(self, parent, block):
return block.startswith(' '*self.tab_length) and \
not self.parser.state.isstate('detabbed') and \
(parent.tag in self.ITEM_TYPES or
(len(parent) and parent[-1] is not None and
(parent[-1].tag in self.LIST_TYPES)))
def run(self, parent, blocks):
block = blocks.pop(0)
level, sibling = self.get_level(parent, block)
block = self.looseDetab(block, level)
self.parser.state.set('detabbed')
if parent.tag in self.ITEM_TYPES:
# It's possible that this parent has a `ul` or `ol` child list
# with a member. If that is the case, then that should be the
# parent. This is intended to catch the edge case of an indented
# list whose first member was parsed previous to this point
# see `OListProcessor`
if len(parent) and parent[-1].tag in self.LIST_TYPES:
self.parser.parseBlocks(parent[-1], [block])
else:
# The parent is already a `li`. Just parse the child block.
self.parser.parseBlocks(parent, [block])
elif sibling.tag in self.ITEM_TYPES:
# The sibling is a `li`. Use it as parent.
self.parser.parseBlocks(sibling, [block])
elif len(sibling) and sibling[-1].tag in self.ITEM_TYPES:
# The parent is a list (`ol` or `ul`) which has children.
# Assume the last child `li` is the parent of this block.
if sibling[-1].text:
# If the parent `li` has text, that text needs to be moved to a `p`
# The `p` must be 'inserted' at beginning of list in the event
# that other children already exist i.e.; a nested sub-list.
p = etree.Element('p')
p.text = sibling[-1].text
sibling[-1].text = ''
sibling[-1].insert(0, p)
self.parser.parseChunk(sibling[-1], block)
else:
self.create_item(sibling, block)
self.parser.state.reset()
def create_item(self, parent: etree.Element, block: str) -> None:
""" Create a new `li` and parse the block with it as the parent. """
li = etree.SubElement(parent, 'li')
self.parser.parseBlocks(li, [block])
def get_level(self, parent: etree.Element, block: str) -> tuple[int, etree.Element]:
""" Get level of indentation based on list level. """
# Get indent level
m = self.INDENT_RE.match(block)
if m:
indent_level = len(m.group(1))/self.tab_length
else:
indent_level = 0
if self.parser.state.isstate('list'):
# We're in a tight-list - so we already are at correct parent.
level = 1
else:
# We're in a loose-list - so we need to find parent.
level = 0
# Step through children of tree to find matching indent level.
while indent_level > level:
child = self.lastChild(parent)
if (child is not None and
(child.tag in self.LIST_TYPES or child.tag in self.ITEM_TYPES)):
if child.tag in self.LIST_TYPES:
level += 1
parent = child
else:
# No more child levels. If we're short of `indent_level`,
# we have a code block. So we stop here.
break
return level, parent
class CodeBlockProcessor(BlockProcessor):
""" Process code blocks. """
def test(self, parent, block):
return block.startswith(' '*self.tab_length)
def run(self, parent, blocks):
sibling = self.lastChild(parent)
block = blocks.pop(0)
theRest = ''
if (sibling is not None and sibling.tag == "pre" and
len(sibling) and sibling[0].tag == "code"):
# The previous block was a code block. As blank lines do not start
# new code blocks, append this block to the previous, adding back
# line breaks removed from the split into a list.
code = sibling[0]
block, theRest = self.detab(block)
code.text = util.AtomicString(
'{}\n{}\n'.format(code.text, util.code_escape(block.rstrip()))
)
else:
# This is a new code block. Create the elements and insert text.
pre = etree.SubElement(parent, 'pre')
code = etree.SubElement(pre, 'code')
block, theRest = self.detab(block)
code.text = util.AtomicString('%s\n' % util.code_escape(block.rstrip()))
if theRest:
# This block contained unindented line(s) after the first indented
# line. Insert these lines as the first block of the master blocks
# list for future processing.
blocks.insert(0, theRest)
class BlockQuoteProcessor(BlockProcessor):
""" Process blockquotes. """
RE = re.compile(r'(^|\n)[ ]{0,3}>[ ]?(.*)')
def test(self, parent, block):
return bool(self.RE.search(block)) and not util.nearing_recursion_limit()
def run(self, parent, blocks):
block = blocks.pop(0)
m = self.RE.search(block)
if m:
before = block[:m.start()] # Lines before blockquote
# Pass lines before blockquote in recursively for parsing first.
self.parser.parseBlocks(parent, [before])
# Remove `> ` from beginning of each line.
block = '\n'.join(
[self.clean(line) for line in block[m.start():].split('\n')]
)
sibling = self.lastChild(parent)
if sibling is not None and sibling.tag == "blockquote":
# Previous block was a blockquote so set that as this blocks parent
quote = sibling
else:
# This is a new blockquote. Create a new parent element.
quote = etree.SubElement(parent, 'blockquote')
# Recursively parse block with blockquote as parent.
# change parser state so blockquotes embedded in lists use `p` tags
self.parser.state.set('blockquote')
self.parser.parseChunk(quote, block)
self.parser.state.reset()
def clean(self, line: str) -> str:
""" Remove `>` from beginning of a line. """
m = self.RE.match(line)
if line.strip() == ">":
return ""
elif m:
return m.group(2)
else:
return line
class OListProcessor(BlockProcessor):
""" Process ordered list blocks. """
TAG: str = 'ol'
""" The tag used for the the wrapping element. """
STARTSWITH: str = '1'
"""
The integer (as a string ) with which the list starts. For example, if a list is initialized as
`3. Item`, then the `ol` tag will be assigned an HTML attribute of `starts="3"`. Default: `"1"`.
"""
LAZY_OL: bool = True
""" Ignore `STARTSWITH` if `True`. """
SIBLING_TAGS: list[str] = ['ol', 'ul']
"""
Markdown does not require the type of a new list item match the previous list item type.
This is the list of types which can be mixed.
"""
def __init__(self, parser: BlockParser):
super().__init__(parser)
# Detect an item (`1. item`). `group(1)` contains contents of item.
self.RE = re.compile(r'^[ ]{0,%d}\d+\.[ ]+(.*)' % (self.tab_length - 1))
# Detect items on secondary lines. they can be of either list type.
self.CHILD_RE = re.compile(r'^[ ]{0,%d}((\d+\.)|[*+-])[ ]+(.*)' %
(self.tab_length - 1))
# Detect indented (nested) items of either type
self.INDENT_RE = re.compile(r'^[ ]{%d,%d}((\d+\.)|[*+-])[ ]+.*' %
(self.tab_length, self.tab_length * 2 - 1))
def test(self, parent, block):
return bool(self.RE.match(block))
def run(self, parent, blocks):
# Check for multiple items in one block.
items = self.get_items(blocks.pop(0))
sibling = self.lastChild(parent)
if sibling is not None and sibling.tag in self.SIBLING_TAGS:
# Previous block was a list item, so set that as parent
lst = sibling
# make sure previous item is in a `p` - if the item has text,
# then it isn't in a `p`
if lst[-1].text:
# since it's possible there are other children for this
# sibling, we can't just `SubElement` the `p`, we need to
# insert it as the first item.
p = etree.Element('p')
p.text = lst[-1].text
lst[-1].text = ''
lst[-1].insert(0, p)
# if the last item has a tail, then the tail needs to be put in a `p`
# likely only when a header is not followed by a blank line
lch = self.lastChild(lst[-1])
if lch is not None and lch.tail:
p = etree.SubElement(lst[-1], 'p')
p.text = lch.tail.lstrip()
lch.tail = ''
# parse first block differently as it gets wrapped in a `p`.
li = etree.SubElement(lst, 'li')
self.parser.state.set('looselist')
firstitem = items.pop(0)
self.parser.parseBlocks(li, [firstitem])
self.parser.state.reset()
elif parent.tag in ['ol', 'ul']:
# this catches the edge case of a multi-item indented list whose
# first item is in a blank parent-list item:
# * * subitem1
# * subitem2
# see also `ListIndentProcessor`
lst = parent
else:
# This is a new list so create parent with appropriate tag.
lst = etree.SubElement(parent, self.TAG)
# Check if a custom start integer is set
if not self.LAZY_OL and self.STARTSWITH != '1':
lst.attrib['start'] = self.STARTSWITH
self.parser.state.set('list')
# Loop through items in block, recursively parsing each with the
# appropriate parent.
for item in items:
if item.startswith(' '*self.tab_length):
# Item is indented. Parse with last item as parent
self.parser.parseBlocks(lst[-1], [item])
else:
# New item. Create `li` and parse with it as parent
li = etree.SubElement(lst, 'li')
self.parser.parseBlocks(li, [item])
self.parser.state.reset()
def get_items(self, block: str) -> list[str]:
""" Break a block into list items. """
items = []
for line in block.split('\n'):
m = self.CHILD_RE.match(line)
if m:
# This is a new list item
# Check first item for the start index
if not items and self.TAG == 'ol':
# Detect the integer value of first list item
INTEGER_RE = re.compile(r'(\d+)')
self.STARTSWITH = INTEGER_RE.match(m.group(1)).group()
# Append to the list
items.append(m.group(3))
elif self.INDENT_RE.match(line):
# This is an indented (possibly nested) item.
if items[-1].startswith(' '*self.tab_length):
# Previous item was indented. Append to that item.
items[-1] = '{}\n{}'.format(items[-1], line)
else:
items.append(line)
else:
# This is another line of previous item. Append to that item.
items[-1] = '{}\n{}'.format(items[-1], line)
return items
class UListProcessor(OListProcessor):
""" Process unordered list blocks. """
TAG: str = 'ul'
""" The tag used for the the wrapping element. """
def __init__(self, parser: BlockParser):
super().__init__(parser)
# Detect an item (`1. item`). `group(1)` contains contents of item.
self.RE = re.compile(r'^[ ]{0,%d}[*+-][ ]+(.*)' % (self.tab_length - 1))
class HashHeaderProcessor(BlockProcessor):
""" Process Hash Headers. """
# Detect a header at start of any line in block
RE = re.compile(r'(?:^|\n)(?P<level>#{1,6})(?P<header>(?:\\.|[^\\])*?)#*(?:\n|$)')
def test(self, parent, block):
return bool(self.RE.search(block))
def run(self, parent, blocks):
block = blocks.pop(0)
m = self.RE.search(block)
if m:
before = block[:m.start()] # All lines before header
after = block[m.end():] # All lines after header
if before:
# As the header was not the first line of the block and the
# lines before the header must be parsed first,
# recursively parse this lines as a block.
self.parser.parseBlocks(parent, [before])
# Create header using named groups from RE
h = etree.SubElement(parent, 'h%d' % len(m.group('level')))
h.text = m.group('header').strip()
if after:
# Insert remaining lines as first block for future parsing.
blocks.insert(0, after)
else: # pragma: no cover
# This should never happen, but just in case...
logger.warn("We've got a problem header: %r" % block)
class SetextHeaderProcessor(BlockProcessor):
""" Process Setext-style Headers. """
# Detect Setext-style header. Must be first 2 lines of block.
RE = re.compile(r'^.*?\n[=-]+[ ]*(\n|$)', re.MULTILINE)
def test(self, parent, block):
return bool(self.RE.match(block))
def run(self, parent, blocks):
lines = blocks.pop(0).split('\n')
# Determine level. `=` is 1 and `-` is 2.
if lines[1].startswith('='):
level = 1
else:
level = 2
h = etree.SubElement(parent, 'h%d' % level)
h.text = lines[0].strip()
if len(lines) > 2:
# Block contains additional lines. Add to master blocks for later.
blocks.insert(0, '\n'.join(lines[2:]))
class HRProcessor(BlockProcessor):
""" Process Horizontal Rules. """
# Python's `re` module doesn't officially support atomic grouping. However you can fake it.
# See https://stackoverflow.com/a/13577411/866026
RE = r'^[ ]{0,3}(?=(?P<atomicgroup>(-+[ ]{0,2}){3,}|(_+[ ]{0,2}){3,}|(\*+[ ]{0,2}){3,}))(?P=atomicgroup)[ ]*$'
# Detect hr on any line of a block.
SEARCH_RE = re.compile(RE, re.MULTILINE)
def test(self, parent, block):
m = self.SEARCH_RE.search(block)
if m:
# Save match object on class instance so we can use it later.
self.match = m
return True
return False
def run(self, parent, blocks):
block = blocks.pop(0)
match = self.match
# Check for lines in block before `hr`.
prelines = block[:match.start()].rstrip('\n')
if prelines:
# Recursively parse lines before `hr` so they get parsed first.
self.parser.parseBlocks(parent, [prelines])
# create hr
etree.SubElement(parent, 'hr')
# check for lines in block after `hr`.
postlines = block[match.end():].lstrip('\n')
if postlines:
# Add lines after `hr` to master blocks for later parsing.
blocks.insert(0, postlines)
class EmptyBlockProcessor(BlockProcessor):
""" Process blocks that are empty or start with an empty line. """
def test(self, parent, block):
return not block or block.startswith('\n')
def run(self, parent, blocks):
block = blocks.pop(0)
filler = '\n\n'
if block:
# Starts with empty line
# Only replace a single line.
filler = '\n'
# Save the rest for later.
theRest = block[1:]
if theRest:
# Add remaining lines to master blocks for later.
blocks.insert(0, theRest)
sibling = self.lastChild(parent)
if (sibling is not None and sibling.tag == 'pre' and
len(sibling) and sibling[0].tag == 'code'):
# Last block is a code block. Append to preserve whitespace.
sibling[0].text = util.AtomicString(
'{}{}'.format(sibling[0].text, filler)
)
class ReferenceProcessor(BlockProcessor):
""" Process link references. """
RE = re.compile(
r'^[ ]{0,3}\[([^\[\]]*)\]:[ ]*\n?[ ]*([^\s]+)[ ]*(?:\n[ ]*)?((["\'])(.*)\4[ ]*|\((.*)\)[ ]*)?$', re.MULTILINE
)
def test(self, parent, block):
return True
def run(self, parent, blocks):
block = blocks.pop(0)
m = self.RE.search(block)
if m:
id = m.group(1).strip().lower()
link = m.group(2).lstrip('<').rstrip('>')
title = m.group(5) or m.group(6)
self.parser.md.references[id] = (link, title)
if block[m.end():].strip():
# Add any content after match back to blocks as separate block
blocks.insert(0, block[m.end():].lstrip('\n'))
if block[:m.start()].strip():
# Add any content before match back to blocks as separate block
blocks.insert(0, block[:m.start()].rstrip('\n'))
return True
# No match. Restore block.
blocks.insert(0, block)
return False
class ParagraphProcessor(BlockProcessor):
""" Process Paragraph blocks. """
def test(self, parent, block):
return True
def run(self, parent, blocks):
block = blocks.pop(0)
if block.strip():
# Not a blank block. Add to parent, otherwise throw it away.
if self.parser.state.isstate('list'):
# The parent is a tight-list.
#
# Check for any children. This will likely only happen in a
# tight-list when a header isn't followed by a blank line.
# For example:
#
# * # Header
# Line 2 of list item - not part of header.
sibling = self.lastChild(parent)
if sibling is not None:
# Insert after sibling.
if sibling.tail:
sibling.tail = '{}\n{}'.format(sibling.tail, block)
else:
sibling.tail = '\n%s' % block
else:
# Append to parent.text
if parent.text:
parent.text = '{}\n{}'.format(parent.text, block)
else:
parent.text = block.lstrip()
else:
# Create a regular paragraph
p = etree.SubElement(parent, 'p')
p.text = block.lstrip()

View File

@@ -0,0 +1,510 @@
# Python Markdown
# A Python implementation of John Gruber's Markdown.
# Documentation: https://python-markdown.github.io/
# GitHub: https://github.com/Python-Markdown/markdown/
# PyPI: https://pypi.org/project/Markdown/
# Started by Manfred Stienstra (http://www.dwerg.net/).
# Maintained for a few years by Yuri Takhteyev (http://www.freewisdom.org).
# Currently maintained by Waylan Limberg (https://github.com/waylan),
# Dmitry Shachnev (https://github.com/mitya57) and Isaac Muse (https://github.com/facelessuser).
# Copyright 2007-2023 The Python Markdown Project (v. 1.7 and later)
# Copyright 2004, 2005, 2006 Yuri Takhteyev (v. 0.2-1.6b)
# Copyright 2004 Manfred Stienstra (the original version)
# License: BSD (see LICENSE.md for details).
from __future__ import annotations
import codecs
import sys
import logging
import importlib
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Mapping, Sequence, TextIO
from . import util
from .preprocessors import build_preprocessors
from .blockprocessors import build_block_parser
from .treeprocessors import build_treeprocessors
from .inlinepatterns import build_inlinepatterns
from .postprocessors import build_postprocessors
from .extensions import Extension
from .serializers import to_html_string, to_xhtml_string
from .util import BLOCK_LEVEL_ELEMENTS
if TYPE_CHECKING: # pragma: no cover
from xml.etree.ElementTree import Element
__all__ = ['Markdown', 'markdown', 'markdownFromFile']
logger = logging.getLogger('MARKDOWN')
class Markdown:
"""
A parser which converts Markdown to HTML.
Attributes:
Markdown.tab_length (int): The number of spaces which correspond to a single tab. Default: `4`.
Markdown.ESCAPED_CHARS (list[str]): List of characters which get the backslash escape treatment.
Markdown.block_level_elements (list[str]): List of HTML tags which get treated as block-level elements.
See [`markdown.util.BLOCK_LEVEL_ELEMENTS`][] for the full list of elements.
Markdown.registeredExtensions (list[Extension]): List of extensions which have called
[`registerExtension`][markdown.Markdown.registerExtension] during setup.
Markdown.doc_tag (str): Element used to wrap document. Default: `div`.
Markdown.stripTopLevelTags (bool): Indicates whether the `doc_tag` should be removed. Default: 'True'.
Markdown.references (dict[str, tuple[str, str]]): A mapping of link references found in a parsed document
where the key is the reference name and the value is a tuple of the URL and title.
Markdown.htmlStash (util.HtmlStash): The instance of the `HtmlStash` used by an instance of this class.
Markdown.output_formats (dict[str, Callable[xml.etree.ElementTree.Element]]): A mapping of known output
formats by name and their respective serializers. Each serializer must be a callable which accepts an
[`Element`][xml.etree.ElementTree.Element] and returns a `str`.
Markdown.output_format (str): The output format set by
[`set_output_format`][markdown.Markdown.set_output_format].
Markdown.serializer (Callable[xml.etree.ElementTree.Element]): The serializer set by
[`set_output_format`][markdown.Markdown.set_output_format].
Markdown.preprocessors (util.Registry): A collection of [`preprocessors`][markdown.preprocessors].
Markdown.parser (blockparser.BlockParser): A collection of [`blockprocessors`][markdown.blockprocessors].
Markdown.inlinePatterns (util.Registry): A collection of [`inlinepatterns`][markdown.inlinepatterns].
Markdown.treeprocessors (util.Registry): A collection of [`treeprocessors`][markdown.treeprocessors].
Markdown.postprocessors (util.Registry): A collection of [`postprocessors`][markdown.postprocessors].
"""
doc_tag = "div" # Element used to wrap document - later removed
output_formats: ClassVar[dict[str, Callable[[Element], str]]] = {
'html': to_html_string,
'xhtml': to_xhtml_string,
}
"""
A mapping of known output formats by name and their respective serializers. Each serializer must be a
callable which accepts an [`Element`][xml.etree.ElementTree.Element] and returns a `str`.
"""
def __init__(self, **kwargs):
"""
Creates a new Markdown instance.
Keyword Arguments:
extensions (list[Extension | str]): A list of extensions.
If an item is an instance of a subclass of [`markdown.extensions.Extension`][],
the instance will be used as-is. If an item is of type `str`, it is passed
to [`build_extension`][markdown.Markdown.build_extension] with its corresponding
`extension_configs` and the returned instance of [`markdown.extensions.Extension`][]
is used.
extension_configs (dict[str, dict[str, Any]]): Configuration settings for extensions.
output_format (str): Format of output. Supported formats are:
* `xhtml`: Outputs XHTML style tags. Default.
* `html`: Outputs HTML style tags.
tab_length (int): Length of tabs in the source. Default: `4`
"""
self.tab_length: int = kwargs.get('tab_length', 4)
self.ESCAPED_CHARS: list[str] = [
'\\', '`', '*', '_', '{', '}', '[', ']', '(', ')', '>', '#', '+', '-', '.', '!'
]
""" List of characters which get the backslash escape treatment. """
self.block_level_elements: list[str] = BLOCK_LEVEL_ELEMENTS.copy()
self.registeredExtensions: list[Extension] = []
self.docType = "" # TODO: Maybe delete this. It does not appear to be used anymore.
self.stripTopLevelTags: bool = True
self.build_parser()
self.references: dict[str, tuple[str, str]] = {}
self.htmlStash: util.HtmlStash = util.HtmlStash()
self.registerExtensions(extensions=kwargs.get('extensions', []),
configs=kwargs.get('extension_configs', {}))
self.set_output_format(kwargs.get('output_format', 'xhtml'))
self.reset()
def build_parser(self) -> Markdown:
"""
Build the parser from the various parts.
Assigns a value to each of the following attributes on the class instance:
* **`Markdown.preprocessors`** ([`Registry`][markdown.util.Registry]) -- A collection of
[`preprocessors`][markdown.preprocessors].
* **`Markdown.parser`** ([`BlockParser`][markdown.blockparser.BlockParser]) -- A collection of
[`blockprocessors`][markdown.blockprocessors].
* **`Markdown.inlinePatterns`** ([`Registry`][markdown.util.Registry]) -- A collection of
[`inlinepatterns`][markdown.inlinepatterns].
* **`Markdown.treeprocessors`** ([`Registry`][markdown.util.Registry]) -- A collection of
[`treeprocessors`][markdown.treeprocessors].
* **`Markdown.postprocessors`** ([`Registry`][markdown.util.Registry]) -- A collection of
[`postprocessors`][markdown.postprocessors].
This method could be redefined in a subclass to build a custom parser which is made up of a different
combination of processors and patterns.
"""
self.preprocessors = build_preprocessors(self)
self.parser = build_block_parser(self)
self.inlinePatterns = build_inlinepatterns(self)
self.treeprocessors = build_treeprocessors(self)
self.postprocessors = build_postprocessors(self)
return self
def registerExtensions(
self,
extensions: Sequence[Extension | str],
configs: Mapping[str, Mapping[str, Any]]
) -> Markdown:
"""
Load a list of extensions into an instance of the `Markdown` class.
Arguments:
extensions (list[Extension | str]): A list of extensions.
If an item is an instance of a subclass of [`markdown.extensions.Extension`][],
the instance will be used as-is. If an item is of type `str`, it is passed
to [`build_extension`][markdown.Markdown.build_extension] with its corresponding `configs` and the
returned instance of [`markdown.extensions.Extension`][] is used.
configs (dict[str, dict[str, Any]]): Configuration settings for extensions.
"""
for ext in extensions:
if isinstance(ext, str):
ext = self.build_extension(ext, configs.get(ext, {}))
if isinstance(ext, Extension):
ext.extendMarkdown(self)
logger.debug(
'Successfully loaded extension "%s.%s".'
% (ext.__class__.__module__, ext.__class__.__name__)
)
elif ext is not None:
raise TypeError(
'Extension "{}.{}" must be of type: "{}.{}"'.format(
ext.__class__.__module__, ext.__class__.__name__,
Extension.__module__, Extension.__name__
)
)
return self
def build_extension(self, ext_name: str, configs: Mapping[str, Any]) -> Extension:
"""
Build extension from a string name, then return an instance using the given `configs`.
Arguments:
ext_name: Name of extension as a string.
configs: Configuration settings for extension.
Returns:
An instance of the extension with the given configuration settings.
First attempt to load an entry point. The string name must be registered as an entry point in the
`markdown.extensions` group which points to a subclass of the [`markdown.extensions.Extension`][] class.
If multiple distributions have registered the same name, the first one found is returned.
If no entry point is found, assume dot notation (`path.to.module:ClassName`). Load the specified class and
return an instance. If no class is specified, import the module and call a `makeExtension` function and return
the [`markdown.extensions.Extension`][] instance returned by that function.
"""
configs = dict(configs)
entry_points = [ep for ep in util.get_installed_extensions() if ep.name == ext_name]
if entry_points:
ext = entry_points[0].load()
return ext(**configs)
# Get class name (if provided): `path.to.module:ClassName`
ext_name, class_name = ext_name.split(':', 1) if ':' in ext_name else (ext_name, '')
try:
module = importlib.import_module(ext_name)
logger.debug(
'Successfully imported extension module "%s".' % ext_name
)
except ImportError as e:
message = 'Failed loading extension "%s".' % ext_name
e.args = (message,) + e.args[1:]
raise
if class_name:
# Load given class name from module.
return getattr(module, class_name)(**configs)
else:
# Expect `makeExtension()` function to return a class.
try:
return module.makeExtension(**configs)
except AttributeError as e:
message = e.args[0]
message = "Failed to initiate extension " \
"'%s': %s" % (ext_name, message)
e.args = (message,) + e.args[1:]
raise
def registerExtension(self, extension: Extension) -> Markdown:
"""
Register an extension as having a resettable state.
Arguments:
extension: An instance of the extension to register.
This should get called once by an extension during setup. A "registered" extension's
`reset` method is called by [`Markdown.reset()`][markdown.Markdown.reset]. Not all extensions have or need a
resettable state, and so it should not be assumed that all extensions are "registered."
"""
self.registeredExtensions.append(extension)
return self
def reset(self) -> Markdown:
"""
Resets all state variables to prepare the parser instance for new input.
Called once upon creation of a class instance. Should be called manually between calls
to [`Markdown.convert`][markdown.Markdown.convert].
"""
self.htmlStash.reset()
self.references.clear()
for extension in self.registeredExtensions:
if hasattr(extension, 'reset'):
extension.reset()
return self
def set_output_format(self, format: str) -> Markdown:
"""
Set the output format for the class instance.
Arguments:
format: Must be a known value in `Markdown.output_formats`.
"""
self.output_format = format.lower().rstrip('145') # ignore number
try:
self.serializer = self.output_formats[self.output_format]
except KeyError as e:
valid_formats = list(self.output_formats.keys())
valid_formats.sort()
message = 'Invalid Output Format: "%s". Use one of %s.' \
% (self.output_format,
'"' + '", "'.join(valid_formats) + '"')
e.args = (message,) + e.args[1:]
raise
return self
# Note: the `tag` argument is type annotated `Any` as ElementTree uses many various objects as tags.
# As there is no standardization in ElementTree, the type of a given tag is unpredictable.
def is_block_level(self, tag: Any) -> bool:
"""
Check if the given `tag` is a block level HTML tag.
Returns `True` for any string listed in `Markdown.block_level_elements`. A `tag` which is
not a string always returns `False`.
"""
if isinstance(tag, str):
return tag.lower().rstrip('/') in self.block_level_elements
# Some ElementTree tags are not strings, so return False.
return False
def convert(self, source: str) -> str:
"""
Convert a Markdown string to a string in the specified output format.
Arguments:
source: Markdown formatted text as Unicode or ASCII string.
Returns:
A string in the specified output format.
Markdown parsing takes place in five steps:
1. A bunch of [`preprocessors`][markdown.preprocessors] munge the input text.
2. A [`BlockParser`][markdown.blockparser.BlockParser] parses the high-level structural elements of the
pre-processed text into an [`ElementTree`][xml.etree.ElementTree.ElementTree] object.
3. A bunch of [`treeprocessors`][markdown.treeprocessors] are run against the
[`ElementTree`][xml.etree.ElementTree.ElementTree] object. One such `treeprocessor`
([`markdown.treeprocessors.InlineProcessor`][]) runs [`inlinepatterns`][markdown.inlinepatterns]
against the [`ElementTree`][xml.etree.ElementTree.ElementTree] object, parsing inline markup.
4. Some [`postprocessors`][markdown.postprocessors] are run against the text after the
[`ElementTree`][xml.etree.ElementTree.ElementTree] object has been serialized into text.
5. The output is returned as a string.
"""
# Fix up the source text
if not source.strip():
return '' # a blank Unicode string
try:
source = str(source)
except UnicodeDecodeError as e: # pragma: no cover
# Customize error message while maintaining original traceback
e.reason += '. -- Note: Markdown only accepts Unicode input!'
raise
# Split into lines and run the line preprocessors.
self.lines = source.split("\n")
for prep in self.preprocessors:
self.lines = prep.run(self.lines)
# Parse the high-level elements.
root = self.parser.parseDocument(self.lines).getroot()
# Run the tree-processors
for treeprocessor in self.treeprocessors:
newRoot = treeprocessor.run(root)
if newRoot is not None:
root = newRoot
# Serialize _properly_. Strip top-level tags.
output = self.serializer(root)
if self.stripTopLevelTags:
try:
start = output.index(
'<%s>' % self.doc_tag) + len(self.doc_tag) + 2
end = output.rindex('</%s>' % self.doc_tag)
output = output[start:end].strip()
except ValueError as e: # pragma: no cover
if output.strip().endswith('<%s />' % self.doc_tag):
# We have an empty document
output = ''
else:
# We have a serious problem
raise ValueError('Markdown failed to strip top-level '
'tags. Document=%r' % output.strip()) from e
# Run the text post-processors
for pp in self.postprocessors:
output = pp.run(output)
return output.strip()
def convertFile(
self,
input: str | TextIO | None = None,
output: str | TextIO | None = None,
encoding: str | None = None,
) -> Markdown:
"""
Converts a Markdown file and returns the HTML as a Unicode string.
Decodes the file using the provided encoding (defaults to `utf-8`),
passes the file content to markdown, and outputs the HTML to either
the provided stream or the file with provided name, using the same
encoding as the source file. The
[`xmlcharrefreplace`](https://docs.python.org/3/library/codecs.html#error-handlers)
error handler is used when encoding the output.
**Note:** This is the only place that decoding and encoding of Unicode
takes place in Python-Markdown. (All other code is Unicode-in /
Unicode-out.)
Arguments:
input: File object or path. Reads from `stdin` if `None`.
output: File object or path. Writes to `stdout` if `None`.
encoding: Encoding of input and output files. Defaults to `utf-8`.
"""
encoding = encoding or "utf-8"
# Read the source
if input:
if isinstance(input, str):
input_file = codecs.open(input, mode="r", encoding=encoding)
else:
input_file = codecs.getreader(encoding)(input)
text = input_file.read()
input_file.close()
else:
text = sys.stdin.read()
if not isinstance(text, str): # pragma: no cover
text = text.decode(encoding)
text = text.lstrip('\ufeff') # remove the byte-order mark
# Convert
html = self.convert(text)
# Write to file or stdout
if output:
if isinstance(output, str):
output_file = codecs.open(output, "w",
encoding=encoding,
errors="xmlcharrefreplace")
output_file.write(html)
output_file.close()
else:
writer = codecs.getwriter(encoding)
output_file = writer(output, errors="xmlcharrefreplace")
output_file.write(html)
# Don't close here. User may want to write more.
else:
# Encode manually and write bytes to stdout.
html = html.encode(encoding, "xmlcharrefreplace")
try:
# Write bytes directly to buffer (Python 3).
sys.stdout.buffer.write(html)
except AttributeError: # pragma: no cover
# Probably Python 2, which works with bytes by default.
sys.stdout.write(html)
return self
"""
EXPORTED FUNCTIONS
=============================================================================
Those are the two functions we really mean to export: `markdown()` and
`markdownFromFile()`.
"""
def markdown(text: str, **kwargs: Any) -> str:
"""
Convert a markdown string to HTML and return HTML as a Unicode string.
This is a shortcut function for [`Markdown`][markdown.Markdown] class to cover the most
basic use case. It initializes an instance of [`Markdown`][markdown.Markdown], loads the
necessary extensions and runs the parser on the given text.
Arguments:
text: Markdown formatted text as Unicode or ASCII string.
Keyword arguments:
**kwargs: Any arguments accepted by the Markdown class.
Returns:
A string in the specified output format.
"""
md = Markdown(**kwargs)
return md.convert(text)
def markdownFromFile(**kwargs: Any):
"""
Read Markdown text from a file and write output to a file or a stream.
This is a shortcut function which initializes an instance of [`Markdown`][markdown.Markdown],
and calls the [`convertFile`][markdown.Markdown.convertFile] method rather than
[`convert`][markdown.Markdown.convert].
Keyword arguments:
input (str | TextIO): A file name or readable object.
output (str | TextIO): A file name or writable object.
encoding (str): Encoding of input and output.
**kwargs: Any arguments accepted by the `Markdown` class.
"""
md = Markdown(**kwargs)
md.convertFile(kwargs.get('input', None),
kwargs.get('output', None),
kwargs.get('encoding', None))

View File

@@ -0,0 +1,145 @@
# Python Markdown
# A Python implementation of John Gruber's Markdown.
# Documentation: https://python-markdown.github.io/
# GitHub: https://github.com/Python-Markdown/markdown/
# PyPI: https://pypi.org/project/Markdown/
# Started by Manfred Stienstra (http://www.dwerg.net/).
# Maintained for a few years by Yuri Takhteyev (http://www.freewisdom.org).
# Currently maintained by Waylan Limberg (https://github.com/waylan),
# Dmitry Shachnev (https://github.com/mitya57) and Isaac Muse (https://github.com/facelessuser).
# Copyright 2007-2023 The Python Markdown Project (v. 1.7 and later)
# Copyright 2004, 2005, 2006 Yuri Takhteyev (v. 0.2-1.6b)
# Copyright 2004 Manfred Stienstra (the original version)
# License: BSD (see LICENSE.md for details).
"""
Markdown accepts an [`Extension`][markdown.extensions.Extension] instance for each extension. Therefore, each extension
must to define a class that extends [`Extension`][markdown.extensions.Extension] and over-rides the
[`extendMarkdown`][markdown.extensions.Extension.extendMarkdown] method. Within this class one can manage configuration
options for their extension and attach the various processors and patterns which make up an extension to the
[`Markdown`][markdown.Markdown] instance.
"""
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Mapping, Sequence
from ..util import parseBoolValue
if TYPE_CHECKING: # pragma: no cover
from markdown import Markdown
class Extension:
""" Base class for extensions to subclass. """
config: Mapping[str, list] = {}
"""
Default configuration for an extension.
This attribute is to be defined in a subclass and must be of the following format:
``` python
config = {
'key': ['value', 'description']
}
```
Note that [`setConfig`][markdown.extensions.Extension.setConfig] will raise a [`KeyError`][]
if a default is not set for each option.
"""
def __init__(self, **kwargs):
""" Initiate Extension and set up configs. """
self.setConfigs(kwargs)
def getConfig(self, key: str, default: Any = '') -> Any:
"""
Return a single configuration option value.
Arguments:
key: The configuration option name.
default: Default value to return if key is not set.
Returns:
Value of stored configuration option.
"""
if key in self.config:
return self.config[key][0]
else:
return default
def getConfigs(self) -> dict[str, Any]:
"""
Return all configuration options.
Returns:
All configuration options.
"""
return {key: self.getConfig(key) for key in self.config.keys()}
def getConfigInfo(self) -> list[tuple[str, str]]:
"""
Return descriptions of all configuration options.
Returns:
All descriptions of configuration options.
"""
return [(key, self.config[key][1]) for key in self.config.keys()]
def setConfig(self, key: str, value: Any) -> None:
"""
Set a configuration option.
If the corresponding default value set in [`config`][markdown.extensions.Extension.config]
is a `bool` value or `None`, then `value` is passed through
[`parseBoolValue`][markdown.util.parseBoolValue] before being stored.
Arguments:
key: Name of configuration option to set.
value: Value to assign to option.
Raises:
KeyError: If `key` is not known.
"""
if isinstance(self.config[key][0], bool):
value = parseBoolValue(value)
if self.config[key][0] is None:
value = parseBoolValue(value, preserve_none=True)
self.config[key][0] = value
def setConfigs(self, items: Mapping[str, Any] | Sequence[tuple[str, Any]]):
"""
Loop through a collection of configuration options, passing each to
[`setConfig`][markdown.extensions.Extension.setConfig].
Arguments:
items: Collection of configuration options.
Raises:
KeyError: for any unknown key.
"""
if hasattr(items, 'items'):
# it's a dict
items = items.items()
for key, value in items:
self.setConfig(key, value)
def extendMarkdown(self, md: Markdown) -> None:
"""
Add the various processors and patterns to the Markdown Instance.
This method must be overridden by every extension.
Arguments:
md: The Markdown instance.
"""
raise NotImplementedError(
'Extension "%s.%s" must define an "extendMarkdown"'
'method.' % (self.__class__.__module__, self.__class__.__name__)
)

View File

@@ -0,0 +1,105 @@
# Abbreviation Extension for Python-Markdown
# ==========================================
# This extension adds abbreviation handling to Python-Markdown.
# See https://Python-Markdown.github.io/extensions/abbreviations
# for documentation.
# Original code Copyright 2007-2008 [Waylan Limberg](http://achinghead.com/)
# and [Seemant Kulleen](http://www.kulleen.org/)
# All changes Copyright 2008-2014 The Python Markdown Project
# License: [BSD](https://opensource.org/licenses/bsd-license.php)
"""
This extension adds abbreviation handling to Python-Markdown.
See the [documentation](https://Python-Markdown.github.io/extensions/abbreviations)
for details.
"""
from __future__ import annotations
from . import Extension
from ..blockprocessors import BlockProcessor
from ..inlinepatterns import InlineProcessor
from ..util import AtomicString
import re
import xml.etree.ElementTree as etree
class AbbrExtension(Extension):
""" Abbreviation Extension for Python-Markdown. """
def extendMarkdown(self, md):
""" Insert `AbbrPreprocessor` before `ReferencePreprocessor`. """
md.parser.blockprocessors.register(AbbrPreprocessor(md.parser), 'abbr', 16)
class AbbrPreprocessor(BlockProcessor):
""" Abbreviation Preprocessor - parse text for abbr references. """
RE = re.compile(r'^[*]\[(?P<abbr>[^\]]*)\][ ]?:[ ]*\n?[ ]*(?P<title>.*)$', re.MULTILINE)
def test(self, parent, block):
return True
def run(self, parent, blocks):
"""
Find and remove all Abbreviation references from the text.
Each reference is set as a new `AbbrPattern` in the markdown instance.
"""
block = blocks.pop(0)
m = self.RE.search(block)
if m:
abbr = m.group('abbr').strip()
title = m.group('title').strip()
self.parser.md.inlinePatterns.register(
AbbrInlineProcessor(self._generate_pattern(abbr), title), 'abbr-%s' % abbr, 2
)
if block[m.end():].strip():
# Add any content after match back to blocks as separate block
blocks.insert(0, block[m.end():].lstrip('\n'))
if block[:m.start()].strip():
# Add any content before match back to blocks as separate block
blocks.insert(0, block[:m.start()].rstrip('\n'))
return True
# No match. Restore block.
blocks.insert(0, block)
return False
def _generate_pattern(self, text):
"""
Given a string, returns an regex pattern to match that string.
'HTML' -> r'(?P<abbr>[H][T][M][L])'
Note: we force each char as a literal match (in brackets) as we don't
know what they will be beforehand.
"""
chars = list(text)
for i in range(len(chars)):
chars[i] = r'[%s]' % chars[i]
return r'(?P<abbr>\b%s\b)' % (r''.join(chars))
class AbbrInlineProcessor(InlineProcessor):
""" Abbreviation inline pattern. """
def __init__(self, pattern, title):
super().__init__(pattern)
self.title = title
def handleMatch(self, m, data):
abbr = etree.Element('abbr')
abbr.text = AtomicString(m.group('abbr'))
abbr.set('title', self.title)
return abbr, m.start(0), m.end(0)
def makeExtension(**kwargs): # pragma: no cover
return AbbrExtension(**kwargs)

View File

@@ -0,0 +1,179 @@
# Admonition extension for Python-Markdown
# ========================================
# Adds rST-style admonitions. Inspired by [rST][] feature with the same name.
# [rST]: http://docutils.sourceforge.net/docs/ref/rst/directives.html#specific-admonitions
# See https://Python-Markdown.github.io/extensions/admonition
# for documentation.
# Original code Copyright [Tiago Serafim](https://www.tiagoserafim.com/).
# All changes Copyright The Python Markdown Project
# License: [BSD](https://opensource.org/licenses/bsd-license.php)
"""
Adds rST-style admonitions. Inspired by [rST][] feature with the same name.
[rST]: http://docutils.sourceforge.net/docs/ref/rst/directives.html#specific-admonitions
See the [documentation](https://Python-Markdown.github.io/extensions/admonition)
for details.
"""
from __future__ import annotations
from . import Extension
from ..blockprocessors import BlockProcessor
import xml.etree.ElementTree as etree
import re
class AdmonitionExtension(Extension):
""" Admonition extension for Python-Markdown. """
def extendMarkdown(self, md):
""" Add Admonition to Markdown instance. """
md.registerExtension(self)
md.parser.blockprocessors.register(AdmonitionProcessor(md.parser), 'admonition', 105)
class AdmonitionProcessor(BlockProcessor):
CLASSNAME = 'admonition'
CLASSNAME_TITLE = 'admonition-title'
RE = re.compile(r'(?:^|\n)!!! ?([\w\-]+(?: +[\w\-]+)*)(?: +"(.*?)")? *(?:\n|$)')
RE_SPACES = re.compile(' +')
def __init__(self, parser):
"""Initialization."""
super().__init__(parser)
self.current_sibling = None
self.content_indention = 0
def parse_content(self, parent, block):
"""Get sibling admonition.
Retrieve the appropriate sibling element. This can get tricky when
dealing with lists.
"""
old_block = block
the_rest = ''
# We already acquired the block via test
if self.current_sibling is not None:
sibling = self.current_sibling
block, the_rest = self.detab(block, self.content_indent)
self.current_sibling = None
self.content_indent = 0
return sibling, block, the_rest
sibling = self.lastChild(parent)
if sibling is None or sibling.tag != 'div' or sibling.get('class', '').find(self.CLASSNAME) == -1:
sibling = None
else:
# If the last child is a list and the content is sufficiently indented
# to be under it, then the content's sibling is in the list.
last_child = self.lastChild(sibling)
indent = 0
while last_child is not None:
if (
sibling is not None and block.startswith(' ' * self.tab_length * 2) and
last_child is not None and last_child.tag in ('ul', 'ol', 'dl')
):
# The expectation is that we'll find an `<li>` or `<dt>`.
# We should get its last child as well.
sibling = self.lastChild(last_child)
last_child = self.lastChild(sibling) if sibling is not None else None
# Context has been lost at this point, so we must adjust the
# text's indentation level so it will be evaluated correctly
# under the list.
block = block[self.tab_length:]
indent += self.tab_length
else:
last_child = None
if not block.startswith(' ' * self.tab_length):
sibling = None
if sibling is not None:
indent += self.tab_length
block, the_rest = self.detab(old_block, indent)
self.current_sibling = sibling
self.content_indent = indent
return sibling, block, the_rest
def test(self, parent, block):
if self.RE.search(block):
return True
else:
return self.parse_content(parent, block)[0] is not None
def run(self, parent, blocks):
block = blocks.pop(0)
m = self.RE.search(block)
if m:
if m.start() > 0:
self.parser.parseBlocks(parent, [block[:m.start()]])
block = block[m.end():] # removes the first line
block, theRest = self.detab(block)
else:
sibling, block, theRest = self.parse_content(parent, block)
if m:
klass, title = self.get_class_and_title(m)
div = etree.SubElement(parent, 'div')
div.set('class', '{} {}'.format(self.CLASSNAME, klass))
if title:
p = etree.SubElement(div, 'p')
p.text = title
p.set('class', self.CLASSNAME_TITLE)
else:
# Sibling is a list item, but we need to wrap it's content should be wrapped in <p>
if sibling.tag in ('li', 'dd') and sibling.text:
text = sibling.text
sibling.text = ''
p = etree.SubElement(sibling, 'p')
p.text = text
div = sibling
self.parser.parseChunk(div, block)
if theRest:
# This block contained unindented line(s) after the first indented
# line. Insert these lines as the first block of the master blocks
# list for future processing.
blocks.insert(0, theRest)
def get_class_and_title(self, match):
klass, title = match.group(1).lower(), match.group(2)
klass = self.RE_SPACES.sub(' ', klass)
if title is None:
# no title was provided, use the capitalized class name as title
# e.g.: `!!! note` will render
# `<p class="admonition-title">Note</p>`
title = klass.split(' ', 1)[0].capitalize()
elif title == '':
# an explicit blank title should not be rendered
# e.g.: `!!! warning ""` will *not* render `p` with a title
title = None
return klass, title
def makeExtension(**kwargs): # pragma: no cover
return AdmonitionExtension(**kwargs)

View File

@@ -0,0 +1,179 @@
# Attribute List Extension for Python-Markdown
# ============================================
# Adds attribute list syntax. Inspired by
# [Maruku](http://maruku.rubyforge.org/proposal.html#attribute_lists)'s
# feature of the same name.
# See https://Python-Markdown.github.io/extensions/attr_list
# for documentation.
# Original code Copyright 2011 [Waylan Limberg](http://achinghead.com/).
# All changes Copyright 2011-2014 The Python Markdown Project
# License: [BSD](https://opensource.org/licenses/bsd-license.php)
"""
Adds attribute list syntax. Inspired by
[Maruku](http://maruku.rubyforge.org/proposal.html#attribute_lists)'s
feature of the same name.
See the [documentation](https://Python-Markdown.github.io/extensions/attr_list)
for details.
"""
from __future__ import annotations
from typing import TYPE_CHECKING
from . import Extension
from ..treeprocessors import Treeprocessor
import re
if TYPE_CHECKING: # pragma: no cover
from xml.etree.ElementTree import Element
def _handle_double_quote(s, t):
k, v = t.split('=', 1)
return k, v.strip('"')
def _handle_single_quote(s, t):
k, v = t.split('=', 1)
return k, v.strip("'")
def _handle_key_value(s, t):
return t.split('=', 1)
def _handle_word(s, t):
if t.startswith('.'):
return '.', t[1:]
if t.startswith('#'):
return 'id', t[1:]
return t, t
_scanner = re.Scanner([
(r'[^ =]+=".*?"', _handle_double_quote),
(r"[^ =]+='.*?'", _handle_single_quote),
(r'[^ =]+=[^ =]+', _handle_key_value),
(r'[^ =]+', _handle_word),
(r' ', None)
])
def get_attrs(str: str) -> list[tuple[str, str]]:
""" Parse attribute list and return a list of attribute tuples. """
return _scanner.scan(str)[0]
def isheader(elem: Element) -> bool:
return elem.tag in ['h1', 'h2', 'h3', 'h4', 'h5', 'h6']
class AttrListTreeprocessor(Treeprocessor):
BASE_RE = r'\{\:?[ ]*([^\}\n ][^\}\n]*)[ ]*\}'
HEADER_RE = re.compile(r'[ ]+{}[ ]*$'.format(BASE_RE))
BLOCK_RE = re.compile(r'\n[ ]*{}[ ]*$'.format(BASE_RE))
INLINE_RE = re.compile(r'^{}'.format(BASE_RE))
NAME_RE = re.compile(r'[^A-Z_a-z\u00c0-\u00d6\u00d8-\u00f6\u00f8-\u02ff'
r'\u0370-\u037d\u037f-\u1fff\u200c-\u200d'
r'\u2070-\u218f\u2c00-\u2fef\u3001-\ud7ff'
r'\uf900-\ufdcf\ufdf0-\ufffd'
r'\:\-\.0-9\u00b7\u0300-\u036f\u203f-\u2040]+')
def run(self, doc: Element):
for elem in doc.iter():
if self.md.is_block_level(elem.tag):
# Block level: check for `attrs` on last line of text
RE = self.BLOCK_RE
if isheader(elem) or elem.tag in ['dt', 'td', 'th']:
# header, def-term, or table cell: check for attributes at end of element
RE = self.HEADER_RE
if len(elem) and elem.tag == 'li':
# special case list items. children may include a `ul` or `ol`.
pos = None
# find the `ul` or `ol` position
for i, child in enumerate(elem):
if child.tag in ['ul', 'ol']:
pos = i
break
if pos is None and elem[-1].tail:
# use tail of last child. no `ul` or `ol`.
m = RE.search(elem[-1].tail)
if m:
self.assign_attrs(elem, m.group(1))
elem[-1].tail = elem[-1].tail[:m.start()]
elif pos is not None and pos > 0 and elem[pos-1].tail:
# use tail of last child before `ul` or `ol`
m = RE.search(elem[pos-1].tail)
if m:
self.assign_attrs(elem, m.group(1))
elem[pos-1].tail = elem[pos-1].tail[:m.start()]
elif elem.text:
# use text. `ul` is first child.
m = RE.search(elem.text)
if m:
self.assign_attrs(elem, m.group(1))
elem.text = elem.text[:m.start()]
elif len(elem) and elem[-1].tail:
# has children. Get from tail of last child
m = RE.search(elem[-1].tail)
if m:
self.assign_attrs(elem, m.group(1))
elem[-1].tail = elem[-1].tail[:m.start()]
if isheader(elem):
# clean up trailing #s
elem[-1].tail = elem[-1].tail.rstrip('#').rstrip()
elif elem.text:
# no children. Get from text.
m = RE.search(elem.text)
if m:
self.assign_attrs(elem, m.group(1))
elem.text = elem.text[:m.start()]
if isheader(elem):
# clean up trailing #s
elem.text = elem.text.rstrip('#').rstrip()
else:
# inline: check for `attrs` at start of tail
if elem.tail:
m = self.INLINE_RE.match(elem.tail)
if m:
self.assign_attrs(elem, m.group(1))
elem.tail = elem.tail[m.end():]
def assign_attrs(self, elem: Element, attrs: str) -> None:
""" Assign `attrs` to element. """
for k, v in get_attrs(attrs):
if k == '.':
# add to class
cls = elem.get('class')
if cls:
elem.set('class', '{} {}'.format(cls, v))
else:
elem.set('class', v)
else:
# assign attribute `k` with `v`
elem.set(self.sanitize_name(k), v)
def sanitize_name(self, name: str) -> str:
"""
Sanitize name as 'an XML Name, minus the ":"'.
See https://www.w3.org/TR/REC-xml-names/#NT-NCName
"""
return self.NAME_RE.sub('_', name)
class AttrListExtension(Extension):
""" Attribute List extension for Python-Markdown """
def extendMarkdown(self, md):
md.treeprocessors.register(AttrListTreeprocessor(md), 'attr_list', 8)
md.registerExtension(self)
def makeExtension(**kwargs): # pragma: no cover
return AttrListExtension(**kwargs)

View File

@@ -0,0 +1,338 @@
# CodeHilite Extension for Python-Markdown
# ========================================
# Adds code/syntax highlighting to standard Python-Markdown code blocks.
# See https://Python-Markdown.github.io/extensions/code_hilite
# for documentation.
# Original code Copyright 2006-2008 [Waylan Limberg](http://achinghead.com/).
# All changes Copyright 2008-2014 The Python Markdown Project
# License: [BSD](https://opensource.org/licenses/bsd-license.php)
"""
Adds code/syntax highlighting to standard Python-Markdown code blocks.
See the [documentation](https://Python-Markdown.github.io/extensions/code_hilite)
for details.
"""
from __future__ import annotations
from . import Extension
from ..treeprocessors import Treeprocessor
from ..util import parseBoolValue
try: # pragma: no cover
from pygments import highlight
from pygments.lexers import get_lexer_by_name, guess_lexer
from pygments.formatters import get_formatter_by_name
from pygments.util import ClassNotFound
pygments = True
except ImportError: # pragma: no cover
pygments = False
def parse_hl_lines(expr: str) -> list[int]:
"""Support our syntax for emphasizing certain lines of code.
`expr` should be like '1 2' to emphasize lines 1 and 2 of a code block.
Returns a list of integers, the line numbers to emphasize.
"""
if not expr:
return []
try:
return list(map(int, expr.split()))
except ValueError: # pragma: no cover
return []
# ------------------ The Main CodeHilite Class ----------------------
class CodeHilite:
"""
Determine language of source code, and pass it on to the Pygments highlighter.
Usage:
```python
code = CodeHilite(src=some_code, lang='python')
html = code.hilite()
```
Arguments:
src: Source string or any object with a `.readline` attribute.
Keyword arguments:
lang (str): String name of Pygments lexer to use for highlighting. Default: `None`.
guess_lang (bool): Auto-detect which lexer to use.
Ignored if `lang` is set to a valid value. Default: `True`.
use_pygments (bool): Pass code to Pygments for code highlighting. If `False`, the code is
instead wrapped for highlighting by a JavaScript library. Default: `True`.
pygments_formatter (str): The name of a Pygments formatter or a formatter class used for
highlighting the code blocks. Default: `html`.
linenums (bool): An alias to Pygments `linenos` formatter option. Default: `None`.
css_class (str): An alias to Pygments `cssclass` formatter option. Default: 'codehilite'.
lang_prefix (str): Prefix prepended to the language. Default: "language-".
Other Options:
Any other options are accepted and passed on to the lexer and formatter. Therefore,
valid options include any options which are accepted by the `html` formatter or
whichever lexer the code's language uses. Note that most lexers do not have any
options. However, a few have very useful options, such as PHP's `startinline` option.
Any invalid options are ignored without error.
* **Formatter options**: <https://pygments.org/docs/formatters/#HtmlFormatter>
* **Lexer Options**: <https://pygments.org/docs/lexers/>
Additionally, when Pygments is enabled, the code's language is passed to the
formatter as an extra option `lang_str`, whose value being `{lang_prefix}{lang}`.
This option has no effect to the Pygments' builtin formatters.
Advanced Usage:
```python
code = CodeHilite(
src = some_code,
lang = 'php',
startinline = True, # Lexer option. Snippet does not start with `<?php`.
linenostart = 42, # Formatter option. Snippet starts on line 42.
hl_lines = [45, 49, 50], # Formatter option. Highlight lines 45, 49, and 50.
linenos = 'inline' # Formatter option. Avoid alignment problems.
)
html = code.hilite()
```
"""
def __init__(self, src: str, **options):
self.src = src
self.lang = options.pop('lang', None)
self.guess_lang = options.pop('guess_lang', True)
self.use_pygments = options.pop('use_pygments', True)
self.lang_prefix = options.pop('lang_prefix', 'language-')
self.pygments_formatter = options.pop('pygments_formatter', 'html')
if 'linenos' not in options:
options['linenos'] = options.pop('linenums', None)
if 'cssclass' not in options:
options['cssclass'] = options.pop('css_class', 'codehilite')
if 'wrapcode' not in options:
# Override Pygments default
options['wrapcode'] = True
# Disallow use of `full` option
options['full'] = False
self.options = options
def hilite(self, shebang=True) -> str:
"""
Pass code to the [Pygments](https://pygments.org/) highlighter with
optional line numbers. The output should then be styled with CSS to
your liking. No styles are applied by default - only styling hooks
(i.e.: `<span class="k">`).
returns : A string of html.
"""
self.src = self.src.strip('\n')
if self.lang is None and shebang:
self._parseHeader()
if pygments and self.use_pygments:
try:
lexer = get_lexer_by_name(self.lang, **self.options)
except ValueError:
try:
if self.guess_lang:
lexer = guess_lexer(self.src, **self.options)
else:
lexer = get_lexer_by_name('text', **self.options)
except ValueError: # pragma: no cover
lexer = get_lexer_by_name('text', **self.options)
if not self.lang:
# Use the guessed lexer's language instead
self.lang = lexer.aliases[0]
lang_str = f'{self.lang_prefix}{self.lang}'
if isinstance(self.pygments_formatter, str):
try:
formatter = get_formatter_by_name(self.pygments_formatter, **self.options)
except ClassNotFound:
formatter = get_formatter_by_name('html', **self.options)
else:
formatter = self.pygments_formatter(lang_str=lang_str, **self.options)
return highlight(self.src, lexer, formatter)
else:
# just escape and build markup usable by JavaScript highlighting libraries
txt = self.src.replace('&', '&amp;')
txt = txt.replace('<', '&lt;')
txt = txt.replace('>', '&gt;')
txt = txt.replace('"', '&quot;')
classes = []
if self.lang:
classes.append('{}{}'.format(self.lang_prefix, self.lang))
if self.options['linenos']:
classes.append('linenums')
class_str = ''
if classes:
class_str = ' class="{}"'.format(' '.join(classes))
return '<pre class="{}"><code{}>{}\n</code></pre>\n'.format(
self.options['cssclass'],
class_str,
txt
)
def _parseHeader(self):
"""
Determines language of a code block from shebang line and whether the
said line should be removed or left in place. If the shebang line
contains a path (even a single /) then it is assumed to be a real
shebang line and left alone. However, if no path is given
(e.i.: `#!python` or `:::python`) then it is assumed to be a mock shebang
for language identification of a code fragment and removed from the
code block prior to processing for code highlighting. When a mock
shebang (e.i: `#!python`) is found, line numbering is turned on. When
colons are found in place of a shebang (e.i.: `:::python`), line
numbering is left in the current state - off by default.
Also parses optional list of highlight lines, like:
:::python hl_lines="1 3"
"""
import re
# split text into lines
lines = self.src.split("\n")
# pull first line to examine
fl = lines.pop(0)
c = re.compile(r'''
(?:(?:^::+)|(?P<shebang>^[#]!)) # Shebang or 2 or more colons
(?P<path>(?:/\w+)*[/ ])? # Zero or 1 path
(?P<lang>[\w#.+-]*) # The language
\s* # Arbitrary whitespace
# Optional highlight lines, single- or double-quote-delimited
(hl_lines=(?P<quot>"|')(?P<hl_lines>.*?)(?P=quot))?
''', re.VERBOSE)
# search first line for shebang
m = c.search(fl)
if m:
# we have a match
try:
self.lang = m.group('lang').lower()
except IndexError: # pragma: no cover
self.lang = None
if m.group('path'):
# path exists - restore first line
lines.insert(0, fl)
if self.options['linenos'] is None and m.group('shebang'):
# Overridable and Shebang exists - use line numbers
self.options['linenos'] = True
self.options['hl_lines'] = parse_hl_lines(m.group('hl_lines'))
else:
# No match
lines.insert(0, fl)
self.src = "\n".join(lines).strip("\n")
# ------------------ The Markdown Extension -------------------------------
class HiliteTreeprocessor(Treeprocessor):
""" Highlight source code in code blocks. """
def code_unescape(self, text):
"""Unescape code."""
text = text.replace("&lt;", "<")
text = text.replace("&gt;", ">")
# Escaped '&' should be replaced at the end to avoid
# conflicting with < and >.
text = text.replace("&amp;", "&")
return text
def run(self, root):
""" Find code blocks and store in `htmlStash`. """
blocks = root.iter('pre')
for block in blocks:
if len(block) == 1 and block[0].tag == 'code':
local_config = self.config.copy()
code = CodeHilite(
self.code_unescape(block[0].text),
tab_length=self.md.tab_length,
style=local_config.pop('pygments_style', 'default'),
**local_config
)
placeholder = self.md.htmlStash.store(code.hilite())
# Clear code block in `etree` instance
block.clear()
# Change to `p` element which will later
# be removed when inserting raw html
block.tag = 'p'
block.text = placeholder
class CodeHiliteExtension(Extension):
""" Add source code highlighting to markdown code blocks. """
def __init__(self, **kwargs):
# define default configs
self.config = {
'linenums': [
None, "Use lines numbers. True|table|inline=yes, False=no, None=auto. Default: `None`."
],
'guess_lang': [
True, "Automatic language detection - Default: `True`."
],
'css_class': [
"codehilite", "Set class name for wrapper <div> - Default: `codehilite`."
],
'pygments_style': [
'default', 'Pygments HTML Formatter Style (Colorscheme). Default: `default`.'
],
'noclasses': [
False, 'Use inline styles instead of CSS classes - Default `False`.'
],
'use_pygments': [
True, 'Highlight code blocks with pygments. Disable if using a JavaScript library. Default: `True`.'
],
'lang_prefix': [
'language-', 'Prefix prepended to the language when `use_pygments` is false. Default: `language-`.'
],
'pygments_formatter': [
'html', 'Use a specific formatter for Pygments highlighting. Default: `html`.'
],
}
""" Default configuration options. """
for key, value in kwargs.items():
if key in self.config:
self.setConfig(key, value)
else:
# manually set unknown keywords.
if isinstance(value, str):
try:
# Attempt to parse `str` as a boolean value
value = parseBoolValue(value, preserve_none=True)
except ValueError:
pass # Assume it's not a boolean value. Use as-is.
self.config[key] = [value, '']
def extendMarkdown(self, md):
""" Add `HilitePostprocessor` to Markdown instance. """
hiliter = HiliteTreeprocessor(md)
hiliter.config = self.getConfigs()
md.treeprocessors.register(hiliter, 'hilite', 30)
md.registerExtension(self)
def makeExtension(**kwargs): # pragma: no cover
return CodeHiliteExtension(**kwargs)

View File

@@ -0,0 +1,119 @@
# Definition List Extension for Python-Markdown
# =============================================
# Adds parsing of Definition Lists to Python-Markdown.
# See https://Python-Markdown.github.io/extensions/definition_lists
# for documentation.
# Original code Copyright 2008 [Waylan Limberg](http://achinghead.com)
# All changes Copyright 2008-2014 The Python Markdown Project
# License: [BSD](https://opensource.org/licenses/bsd-license.php)
"""
Adds parsing of Definition Lists to Python-Markdown.
See the [documentation](https://Python-Markdown.github.io/extensions/definition_lists)
for details.
"""
from __future__ import annotations
from . import Extension
from ..blockprocessors import BlockProcessor, ListIndentProcessor
import xml.etree.ElementTree as etree
import re
class DefListProcessor(BlockProcessor):
""" Process Definition Lists. """
RE = re.compile(r'(^|\n)[ ]{0,3}:[ ]{1,3}(.*?)(\n|$)')
NO_INDENT_RE = re.compile(r'^[ ]{0,3}[^ :]')
def test(self, parent, block):
return bool(self.RE.search(block))
def run(self, parent, blocks):
raw_block = blocks.pop(0)
m = self.RE.search(raw_block)
terms = [term.strip() for term in
raw_block[:m.start()].split('\n') if term.strip()]
block = raw_block[m.end():]
no_indent = self.NO_INDENT_RE.match(block)
if no_indent:
d, theRest = (block, None)
else:
d, theRest = self.detab(block)
if d:
d = '{}\n{}'.format(m.group(2), d)
else:
d = m.group(2)
sibling = self.lastChild(parent)
if not terms and sibling is None:
# This is not a definition item. Most likely a paragraph that
# starts with a colon at the beginning of a document or list.
blocks.insert(0, raw_block)
return False
if not terms and sibling.tag == 'p':
# The previous paragraph contains the terms
state = 'looselist'
terms = sibling.text.split('\n')
parent.remove(sibling)
# Acquire new sibling
sibling = self.lastChild(parent)
else:
state = 'list'
if sibling is not None and sibling.tag == 'dl':
# This is another item on an existing list
dl = sibling
if not terms and len(dl) and dl[-1].tag == 'dd' and len(dl[-1]):
state = 'looselist'
else:
# This is a new list
dl = etree.SubElement(parent, 'dl')
# Add terms
for term in terms:
dt = etree.SubElement(dl, 'dt')
dt.text = term
# Add definition
self.parser.state.set(state)
dd = etree.SubElement(dl, 'dd')
self.parser.parseBlocks(dd, [d])
self.parser.state.reset()
if theRest:
blocks.insert(0, theRest)
class DefListIndentProcessor(ListIndentProcessor):
""" Process indented children of definition list items. """
# Definition lists need to be aware of all list types
ITEM_TYPES = ['dd', 'li']
""" Include `dd` in list item types. """
LIST_TYPES = ['dl', 'ol', 'ul']
""" Include `dl` is list types. """
def create_item(self, parent, block):
""" Create a new `dd` or `li` (depending on parent) and parse the block with it as the parent. """
dd = etree.SubElement(parent, 'dd')
self.parser.parseBlocks(dd, [block])
class DefListExtension(Extension):
""" Add definition lists to Markdown. """
def extendMarkdown(self, md):
""" Add an instance of `DefListProcessor` to `BlockParser`. """
md.parser.blockprocessors.register(DefListIndentProcessor(md.parser), 'defindent', 85)
md.parser.blockprocessors.register(DefListProcessor(md.parser), 'deflist', 25)
def makeExtension(**kwargs): # pragma: no cover
return DefListExtension(**kwargs)

View File

@@ -0,0 +1,66 @@
# Python-Markdown Extra Extension
# ===============================
# A compilation of various Python-Markdown extensions that imitates
# [PHP Markdown Extra](http://michelf.com/projects/php-markdown/extra/).
# See https://Python-Markdown.github.io/extensions/extra
# for documentation.
# Copyright The Python Markdown Project
# License: [BSD](https://opensource.org/licenses/bsd-license.php)
"""
A compilation of various Python-Markdown extensions that imitates
[PHP Markdown Extra](http://michelf.com/projects/php-markdown/extra/).
Note that each of the individual extensions still need to be available
on your `PYTHONPATH`. This extension simply wraps them all up as a
convenience so that only one extension needs to be listed when
initiating Markdown. See the documentation for each individual
extension for specifics about that extension.
There may be additional extensions that are distributed with
Python-Markdown that are not included here in Extra. Those extensions
are not part of PHP Markdown Extra, and therefore, not part of
Python-Markdown Extra. If you really would like Extra to include
additional extensions, we suggest creating your own clone of Extra
under a different name. You could also edit the `extensions` global
variable defined below, but be aware that such changes may be lost
when you upgrade to any future version of Python-Markdown.
See the [documentation](https://Python-Markdown.github.io/extensions/extra)
for details.
"""
from __future__ import annotations
from . import Extension
extensions = [
'fenced_code',
'footnotes',
'attr_list',
'def_list',
'tables',
'abbr',
'md_in_html'
]
""" The list of included extensions. """
class ExtraExtension(Extension):
""" Add various extensions to Markdown class."""
def __init__(self, **kwargs):
""" `config` is a dumb holder which gets passed to the actual extension later. """
self.config = kwargs
def extendMarkdown(self, md):
""" Register extension instances. """
md.registerExtensions(extensions, self.config)
def makeExtension(**kwargs): # pragma: no cover
return ExtraExtension(**kwargs)

View File

@@ -0,0 +1,182 @@
# Fenced Code Extension for Python Markdown
# =========================================
# This extension adds Fenced Code Blocks to Python-Markdown.
# See https://Python-Markdown.github.io/extensions/fenced_code_blocks
# for documentation.
# Original code Copyright 2007-2008 [Waylan Limberg](http://achinghead.com/).
# All changes Copyright 2008-2014 The Python Markdown Project
# License: [BSD](https://opensource.org/licenses/bsd-license.php)
"""
This extension adds Fenced Code Blocks to Python-Markdown.
See the [documentation](https://Python-Markdown.github.io/extensions/fenced_code_blocks)
for details.
"""
from __future__ import annotations
from textwrap import dedent
from . import Extension
from ..preprocessors import Preprocessor
from .codehilite import CodeHilite, CodeHiliteExtension, parse_hl_lines
from .attr_list import get_attrs, AttrListExtension
from ..util import parseBoolValue
from ..serializers import _escape_attrib_html
import re
class FencedCodeExtension(Extension):
def __init__(self, **kwargs):
self.config = {
'lang_prefix': ['language-', 'Prefix prepended to the language. Default: "language-"']
}
""" Default configuration options. """
super().__init__(**kwargs)
def extendMarkdown(self, md):
""" Add `FencedBlockPreprocessor` to the Markdown instance. """
md.registerExtension(self)
md.preprocessors.register(FencedBlockPreprocessor(md, self.getConfigs()), 'fenced_code_block', 25)
class FencedBlockPreprocessor(Preprocessor):
""" Find and extract fenced code blocks. """
FENCED_BLOCK_RE = re.compile(
dedent(r'''
(?P<fence>^(?:~{3,}|`{3,}))[ ]* # opening fence
((\{(?P<attrs>[^\}\n]*)\})| # (optional {attrs} or
(\.?(?P<lang>[\w#.+-]*)[ ]*)? # optional (.)lang
(hl_lines=(?P<quot>"|')(?P<hl_lines>.*?)(?P=quot)[ ]*)?) # optional hl_lines)
\n # newline (end of opening fence)
(?P<code>.*?)(?<=\n) # the code block
(?P=fence)[ ]*$ # closing fence
'''),
re.MULTILINE | re.DOTALL | re.VERBOSE
)
def __init__(self, md, config):
super().__init__(md)
self.config = config
self.checked_for_deps = False
self.codehilite_conf = {}
self.use_attr_list = False
# List of options to convert to boolean values
self.bool_options = [
'linenums',
'guess_lang',
'noclasses',
'use_pygments'
]
def run(self, lines):
""" Match and store Fenced Code Blocks in the `HtmlStash`. """
# Check for dependent extensions
if not self.checked_for_deps:
for ext in self.md.registeredExtensions:
if isinstance(ext, CodeHiliteExtension):
self.codehilite_conf = ext.getConfigs()
if isinstance(ext, AttrListExtension):
self.use_attr_list = True
self.checked_for_deps = True
text = "\n".join(lines)
while 1:
m = self.FENCED_BLOCK_RE.search(text)
if m:
lang, id, classes, config = None, '', [], {}
if m.group('attrs'):
id, classes, config = self.handle_attrs(get_attrs(m.group('attrs')))
if len(classes):
lang = classes.pop(0)
else:
if m.group('lang'):
lang = m.group('lang')
if m.group('hl_lines'):
# Support `hl_lines` outside of `attrs` for backward-compatibility
config['hl_lines'] = parse_hl_lines(m.group('hl_lines'))
# If `config` is not empty, then the `codehighlite` extension
# is enabled, so we call it to highlight the code
if self.codehilite_conf and self.codehilite_conf['use_pygments'] and config.get('use_pygments', True):
local_config = self.codehilite_conf.copy()
local_config.update(config)
# Combine classes with `cssclass`. Ensure `cssclass` is at end
# as Pygments appends a suffix under certain circumstances.
# Ignore ID as Pygments does not offer an option to set it.
if classes:
local_config['css_class'] = '{} {}'.format(
' '.join(classes),
local_config['css_class']
)
highliter = CodeHilite(
m.group('code'),
lang=lang,
style=local_config.pop('pygments_style', 'default'),
**local_config
)
code = highliter.hilite(shebang=False)
else:
id_attr = lang_attr = class_attr = kv_pairs = ''
if lang:
prefix = self.config.get('lang_prefix', 'language-')
lang_attr = f' class="{prefix}{_escape_attrib_html(lang)}"'
if classes:
class_attr = f' class="{_escape_attrib_html(" ".join(classes))}"'
if id:
id_attr = f' id="{_escape_attrib_html(id)}"'
if self.use_attr_list and config and not config.get('use_pygments', False):
# Only assign key/value pairs to code element if `attr_list` extension is enabled, key/value
# pairs were defined on the code block, and the `use_pygments` key was not set to `True`. The
# `use_pygments` key could be either set to `False` or not defined. It is omitted from output.
kv_pairs = ''.join(
f' {k}="{_escape_attrib_html(v)}"' for k, v in config.items() if k != 'use_pygments'
)
code = self._escape(m.group('code'))
code = f'<pre{id_attr}{class_attr}><code{lang_attr}{kv_pairs}>{code}</code></pre>'
placeholder = self.md.htmlStash.store(code)
text = f'{text[:m.start()]}\n{placeholder}\n{text[m.end():]}'
else:
break
return text.split("\n")
def handle_attrs(self, attrs):
""" Return tuple: `(id, [list, of, classes], {configs})` """
id = ''
classes = []
configs = {}
for k, v in attrs:
if k == 'id':
id = v
elif k == '.':
classes.append(v)
elif k == 'hl_lines':
configs[k] = parse_hl_lines(v)
elif k in self.bool_options:
configs[k] = parseBoolValue(v, fail_on_errors=False, preserve_none=True)
else:
configs[k] = v
return id, classes, configs
def _escape(self, txt):
""" basic html escaping """
txt = txt.replace('&', '&amp;')
txt = txt.replace('<', '&lt;')
txt = txt.replace('>', '&gt;')
txt = txt.replace('"', '&quot;')
return txt
def makeExtension(**kwargs): # pragma: no cover
return FencedCodeExtension(**kwargs)

View File

@@ -0,0 +1,416 @@
# Footnotes Extension for Python-Markdown
# =======================================
# Adds footnote handling to Python-Markdown.
# See https://Python-Markdown.github.io/extensions/footnotes
# for documentation.
# Copyright The Python Markdown Project
# License: [BSD](https://opensource.org/licenses/bsd-license.php)
"""
Adds footnote handling to Python-Markdown.
See the [documentation](https://Python-Markdown.github.io/extensions/footnotes)
for details.
"""
from __future__ import annotations
from . import Extension
from ..blockprocessors import BlockProcessor
from ..inlinepatterns import InlineProcessor
from ..treeprocessors import Treeprocessor
from ..postprocessors import Postprocessor
from .. import util
from collections import OrderedDict
import re
import copy
import xml.etree.ElementTree as etree
FN_BACKLINK_TEXT = util.STX + "zz1337820767766393qq" + util.ETX
NBSP_PLACEHOLDER = util.STX + "qq3936677670287331zz" + util.ETX
RE_REF_ID = re.compile(r'(fnref)(\d+)')
class FootnoteExtension(Extension):
""" Footnote Extension. """
def __init__(self, **kwargs):
""" Setup configs. """
self.config = {
'PLACE_MARKER': [
'///Footnotes Go Here///', 'The text string that marks where the footnotes go'
],
'UNIQUE_IDS': [
False, 'Avoid name collisions across multiple calls to `reset()`.'
],
'BACKLINK_TEXT': [
'&#8617;', "The text string that links from the footnote to the reader's place."
],
'SUPERSCRIPT_TEXT': [
'{}', "The text string that links from the reader's place to the footnote."
],
'BACKLINK_TITLE': [
'Jump back to footnote %d in the text',
'The text string used for the title HTML attribute of the backlink. '
'%d will be replaced by the footnote number.'
],
'SEPARATOR': [
':', 'Footnote separator.'
]
}
""" Default configuration options. """
super().__init__(**kwargs)
# In multiple invocations, emit links that don't get tangled.
self.unique_prefix = 0
self.found_refs = {}
self.used_refs = set()
self.reset()
def extendMarkdown(self, md):
""" Add pieces to Markdown. """
md.registerExtension(self)
self.parser = md.parser
self.md = md
# Insert a `blockprocessor` before `ReferencePreprocessor`
md.parser.blockprocessors.register(FootnoteBlockProcessor(self), 'footnote', 17)
# Insert an inline pattern before `ImageReferencePattern`
FOOTNOTE_RE = r'\[\^([^\]]*)\]' # blah blah [^1] blah
md.inlinePatterns.register(FootnoteInlineProcessor(FOOTNOTE_RE, self), 'footnote', 175)
# Insert a tree-processor that would actually add the footnote div
# This must be before all other tree-processors (i.e., `inline` and
# `codehilite`) so they can run on the the contents of the div.
md.treeprocessors.register(FootnoteTreeprocessor(self), 'footnote', 50)
# Insert a tree-processor that will run after inline is done.
# In this tree-processor we want to check our duplicate footnote tracker
# And add additional `backrefs` to the footnote pointing back to the
# duplicated references.
md.treeprocessors.register(FootnotePostTreeprocessor(self), 'footnote-duplicate', 15)
# Insert a postprocessor after amp_substitute processor
md.postprocessors.register(FootnotePostprocessor(self), 'footnote', 25)
def reset(self) -> None:
""" Clear footnotes on reset, and prepare for distinct document. """
self.footnotes: OrderedDict[str, str] = OrderedDict()
self.unique_prefix += 1
self.found_refs = {}
self.used_refs = set()
def unique_ref(self, reference, found: bool = False):
""" Get a unique reference if there are duplicates. """
if not found:
return reference
original_ref = reference
while reference in self.used_refs:
ref, rest = reference.split(self.get_separator(), 1)
m = RE_REF_ID.match(ref)
if m:
reference = '%s%d%s%s' % (m.group(1), int(m.group(2))+1, self.get_separator(), rest)
else:
reference = '%s%d%s%s' % (ref, 2, self.get_separator(), rest)
self.used_refs.add(reference)
if original_ref in self.found_refs:
self.found_refs[original_ref] += 1
else:
self.found_refs[original_ref] = 1
return reference
def findFootnotesPlaceholder(self, root):
""" Return ElementTree Element that contains Footnote placeholder. """
def finder(element):
for child in element:
if child.text:
if child.text.find(self.getConfig("PLACE_MARKER")) > -1:
return child, element, True
if child.tail:
if child.tail.find(self.getConfig("PLACE_MARKER")) > -1:
return child, element, False
child_res = finder(child)
if child_res is not None:
return child_res
return None
res = finder(root)
return res
def setFootnote(self, id, text) -> None:
""" Store a footnote for later retrieval. """
self.footnotes[id] = text
def get_separator(self):
""" Get the footnote separator. """
return self.getConfig("SEPARATOR")
def makeFootnoteId(self, id):
""" Return footnote link id. """
if self.getConfig("UNIQUE_IDS"):
return 'fn%s%d-%s' % (self.get_separator(), self.unique_prefix, id)
else:
return 'fn{}{}'.format(self.get_separator(), id)
def makeFootnoteRefId(self, id, found: bool = False):
""" Return footnote back-link id. """
if self.getConfig("UNIQUE_IDS"):
return self.unique_ref('fnref%s%d-%s' % (self.get_separator(), self.unique_prefix, id), found)
else:
return self.unique_ref('fnref{}{}'.format(self.get_separator(), id), found)
def makeFootnotesDiv(self, root):
""" Return `div` of footnotes as `etree` Element. """
if not list(self.footnotes.keys()):
return None
div = etree.Element("div")
div.set('class', 'footnote')
etree.SubElement(div, "hr")
ol = etree.SubElement(div, "ol")
surrogate_parent = etree.Element("div")
# Backward compatibility with old '%d' placeholder
backlink_title = self.getConfig("BACKLINK_TITLE").replace("%d", "{}")
for index, id in enumerate(self.footnotes.keys(), start=1):
li = etree.SubElement(ol, "li")
li.set("id", self.makeFootnoteId(id))
# Parse footnote with surrogate parent as `li` cannot be used.
# List block handlers have special logic to deal with `li`.
# When we are done parsing, we will copy everything over to `li`.
self.parser.parseChunk(surrogate_parent, self.footnotes[id])
for el in list(surrogate_parent):
li.append(el)
surrogate_parent.remove(el)
backlink = etree.Element("a")
backlink.set("href", "#" + self.makeFootnoteRefId(id))
backlink.set("class", "footnote-backref")
backlink.set(
"title",
backlink_title.format(index)
)
backlink.text = FN_BACKLINK_TEXT
if len(li):
node = li[-1]
if node.tag == "p":
node.text = node.text + NBSP_PLACEHOLDER
node.append(backlink)
else:
p = etree.SubElement(li, "p")
p.append(backlink)
return div
class FootnoteBlockProcessor(BlockProcessor):
""" Find all footnote references and store for later use. """
RE = re.compile(r'^[ ]{0,3}\[\^([^\]]*)\]:[ ]*(.*)$', re.MULTILINE)
def __init__(self, footnotes):
super().__init__(footnotes.parser)
self.footnotes = footnotes
def test(self, parent, block):
return True
def run(self, parent, blocks):
""" Find, set, and remove footnote definitions. """
block = blocks.pop(0)
m = self.RE.search(block)
if m:
id = m.group(1)
fn_blocks = [m.group(2)]
# Handle rest of block
therest = block[m.end():].lstrip('\n')
m2 = self.RE.search(therest)
if m2:
# Another footnote exists in the rest of this block.
# Any content before match is continuation of this footnote, which may be lazily indented.
before = therest[:m2.start()].rstrip('\n')
fn_blocks[0] = '\n'.join([fn_blocks[0], self.detab(before)]).lstrip('\n')
# Add back to blocks everything from beginning of match forward for next iteration.
blocks.insert(0, therest[m2.start():])
else:
# All remaining lines of block are continuation of this footnote, which may be lazily indented.
fn_blocks[0] = '\n'.join([fn_blocks[0], self.detab(therest)]).strip('\n')
# Check for child elements in remaining blocks.
fn_blocks.extend(self.detectTabbed(blocks))
footnote = "\n\n".join(fn_blocks)
self.footnotes.setFootnote(id, footnote.rstrip())
if block[:m.start()].strip():
# Add any content before match back to blocks as separate block
blocks.insert(0, block[:m.start()].rstrip('\n'))
return True
# No match. Restore block.
blocks.insert(0, block)
return False
def detectTabbed(self, blocks) -> list[str]:
""" Find indented text and remove indent before further processing.
Returns:
A list of blocks with indentation removed.
"""
fn_blocks = []
while blocks:
if blocks[0].startswith(' '*4):
block = blocks.pop(0)
# Check for new footnotes within this block and split at new footnote.
m = self.RE.search(block)
if m:
# Another footnote exists in this block.
# Any content before match is continuation of this footnote, which may be lazily indented.
before = block[:m.start()].rstrip('\n')
fn_blocks.append(self.detab(before))
# Add back to blocks everything from beginning of match forward for next iteration.
blocks.insert(0, block[m.start():])
# End of this footnote.
break
else:
# Entire block is part of this footnote.
fn_blocks.append(self.detab(block))
else:
# End of this footnote.
break
return fn_blocks
def detab(self, block):
""" Remove one level of indent from a block.
Preserve lazily indented blocks by only removing indent from indented lines.
"""
lines = block.split('\n')
for i, line in enumerate(lines):
if line.startswith(' '*4):
lines[i] = line[4:]
return '\n'.join(lines)
class FootnoteInlineProcessor(InlineProcessor):
""" `InlineProcessor` for footnote markers in a document's body text. """
def __init__(self, pattern, footnotes):
super().__init__(pattern)
self.footnotes = footnotes
def handleMatch(self, m, data):
id = m.group(1)
if id in self.footnotes.footnotes.keys():
sup = etree.Element("sup")
a = etree.SubElement(sup, "a")
sup.set('id', self.footnotes.makeFootnoteRefId(id, found=True))
a.set('href', '#' + self.footnotes.makeFootnoteId(id))
a.set('class', 'footnote-ref')
a.text = self.footnotes.getConfig("SUPERSCRIPT_TEXT").format(
list(self.footnotes.footnotes.keys()).index(id) + 1
)
return sup, m.start(0), m.end(0)
else:
return None, None, None
class FootnotePostTreeprocessor(Treeprocessor):
""" Amend footnote div with duplicates. """
def __init__(self, footnotes):
self.footnotes = footnotes
def add_duplicates(self, li, duplicates) -> None:
""" Adjust current `li` and add the duplicates: `fnref2`, `fnref3`, etc. """
for link in li.iter('a'):
# Find the link that needs to be duplicated.
if link.attrib.get('class', '') == 'footnote-backref':
ref, rest = link.attrib['href'].split(self.footnotes.get_separator(), 1)
# Duplicate link the number of times we need to
# and point the to the appropriate references.
links = []
for index in range(2, duplicates + 1):
sib_link = copy.deepcopy(link)
sib_link.attrib['href'] = '%s%d%s%s' % (ref, index, self.footnotes.get_separator(), rest)
links.append(sib_link)
self.offset += 1
# Add all the new duplicate links.
el = list(li)[-1]
for link in links:
el.append(link)
break
def get_num_duplicates(self, li):
""" Get the number of duplicate refs of the footnote. """
fn, rest = li.attrib.get('id', '').split(self.footnotes.get_separator(), 1)
link_id = '{}ref{}{}'.format(fn, self.footnotes.get_separator(), rest)
return self.footnotes.found_refs.get(link_id, 0)
def handle_duplicates(self, parent) -> None:
""" Find duplicate footnotes and format and add the duplicates. """
for li in list(parent):
# Check number of duplicates footnotes and insert
# additional links if needed.
count = self.get_num_duplicates(li)
if count > 1:
self.add_duplicates(li, count)
def run(self, root):
""" Crawl the footnote div and add missing duplicate footnotes. """
self.offset = 0
for div in root.iter('div'):
if div.attrib.get('class', '') == 'footnote':
# Footnotes should be under the first ordered list under
# the footnote div. So once we find it, quit.
for ol in div.iter('ol'):
self.handle_duplicates(ol)
break
class FootnoteTreeprocessor(Treeprocessor):
""" Build and append footnote div to end of document. """
def __init__(self, footnotes):
self.footnotes = footnotes
def run(self, root):
footnotesDiv = self.footnotes.makeFootnotesDiv(root)
if footnotesDiv is not None:
result = self.footnotes.findFootnotesPlaceholder(root)
if result:
child, parent, isText = result
ind = list(parent).index(child)
if isText:
parent.remove(child)
parent.insert(ind, footnotesDiv)
else:
parent.insert(ind + 1, footnotesDiv)
child.tail = None
else:
root.append(footnotesDiv)
class FootnotePostprocessor(Postprocessor):
""" Replace placeholders with html entities. """
def __init__(self, footnotes):
self.footnotes = footnotes
def run(self, text):
text = text.replace(
FN_BACKLINK_TEXT, self.footnotes.getConfig("BACKLINK_TEXT")
)
return text.replace(NBSP_PLACEHOLDER, "&#160;")
def makeExtension(**kwargs): # pragma: no cover
""" Return an instance of the `FootnoteExtension` """
return FootnoteExtension(**kwargs)

View File

@@ -0,0 +1,67 @@
# Python Markdown
# A Python implementation of John Gruber's Markdown.
# Documentation: https://python-markdown.github.io/
# GitHub: https://github.com/Python-Markdown/markdown/
# PyPI: https://pypi.org/project/Markdown/
# Started by Manfred Stienstra (http://www.dwerg.net/).
# Maintained for a few years by Yuri Takhteyev (http://www.freewisdom.org).
# Currently maintained by Waylan Limberg (https://github.com/waylan),
# Dmitry Shachnev (https://github.com/mitya57) and Isaac Muse (https://github.com/facelessuser).
# Copyright 2007-2023 The Python Markdown Project (v. 1.7 and later)
# Copyright 2004, 2005, 2006 Yuri Takhteyev (v. 0.2-1.6b)
# Copyright 2004 Manfred Stienstra (the original version)
# License: BSD (see LICENSE.md for details).
"""
An extension to Python Markdown which implements legacy attributes.
Prior to Python-Markdown version 3.0, the Markdown class had an `enable_attributes`
keyword which was on by default and provided for attributes to be defined for elements
using the format `{@key=value}`. This extension is provided as a replacement for
backward compatibility. New documents should be authored using `attr_lists`. However,
numerous documents exist which have been using the old attribute format for many
years. This extension can be used to continue to render those documents correctly.
"""
from __future__ import annotations
import re
from markdown.treeprocessors import Treeprocessor, isString
from markdown.extensions import Extension
ATTR_RE = re.compile(r'\{@([^\}]*)=([^\}]*)}') # {@id=123}
class LegacyAttrs(Treeprocessor):
def run(self, doc):
"""Find and set values of attributes ({@key=value}). """
for el in doc.iter():
alt = el.get('alt', None)
if alt is not None:
el.set('alt', self.handleAttributes(el, alt))
if el.text and isString(el.text):
el.text = self.handleAttributes(el, el.text)
if el.tail and isString(el.tail):
el.tail = self.handleAttributes(el, el.tail)
def handleAttributes(self, el, txt):
""" Set attributes and return text without definitions. """
def attributeCallback(match):
el.set(match.group(1), match.group(2).replace('\n', ' '))
return ATTR_RE.sub(attributeCallback, txt)
class LegacyAttrExtension(Extension):
def extendMarkdown(self, md):
""" Add `LegacyAttrs` to Markdown instance. """
md.treeprocessors.register(LegacyAttrs(md), 'legacyattrs', 15)
def makeExtension(**kwargs): # pragma: no cover
return LegacyAttrExtension(**kwargs)

View File

@@ -0,0 +1,52 @@
# Legacy Em Extension for Python-Markdown
# =======================================
# This extension provides legacy behavior for _connected_words_.
# Copyright 2015-2018 The Python Markdown Project
# License: [BSD](https://opensource.org/licenses/bsd-license.php)
"""
This extension provides legacy behavior for _connected_words_.
"""
from __future__ import annotations
from . import Extension
from ..inlinepatterns import UnderscoreProcessor, EmStrongItem, EM_STRONG2_RE, STRONG_EM2_RE
import re
# _emphasis_
EMPHASIS_RE = r'(_)([^_]+)\1'
# __strong__
STRONG_RE = r'(_{2})(.+?)\1'
# __strong_em___
STRONG_EM_RE = r'(_)\1(?!\1)([^_]+?)\1(?!\1)(.+?)\1{3}'
class LegacyUnderscoreProcessor(UnderscoreProcessor):
"""Emphasis processor for handling strong and em matches inside underscores."""
PATTERNS = [
EmStrongItem(re.compile(EM_STRONG2_RE, re.DOTALL | re.UNICODE), 'double', 'strong,em'),
EmStrongItem(re.compile(STRONG_EM2_RE, re.DOTALL | re.UNICODE), 'double', 'em,strong'),
EmStrongItem(re.compile(STRONG_EM_RE, re.DOTALL | re.UNICODE), 'double2', 'strong,em'),
EmStrongItem(re.compile(STRONG_RE, re.DOTALL | re.UNICODE), 'single', 'strong'),
EmStrongItem(re.compile(EMPHASIS_RE, re.DOTALL | re.UNICODE), 'single', 'em')
]
class LegacyEmExtension(Extension):
""" Add legacy_em extension to Markdown class."""
def extendMarkdown(self, md):
""" Modify inline patterns. """
md.inlinePatterns.register(LegacyUnderscoreProcessor(r'_'), 'em_strong2', 50)
def makeExtension(**kwargs): # pragma: no cover
""" Return an instance of the `LegacyEmExtension` """
return LegacyEmExtension(**kwargs)

View File

@@ -0,0 +1,372 @@
# Python-Markdown Markdown in HTML Extension
# ===============================
# An implementation of [PHP Markdown Extra](http://michelf.com/projects/php-markdown/extra/)'s
# parsing of Markdown syntax in raw HTML.
# See https://Python-Markdown.github.io/extensions/raw_html
# for documentation.
# Copyright The Python Markdown Project
# License: [BSD](https://opensource.org/licenses/bsd-license.php)
"""
An implementation of [PHP Markdown Extra](http://michelf.com/projects/php-markdown/extra/)'s
parsing of Markdown syntax in raw HTML.
See the [documentation](https://Python-Markdown.github.io/extensions/raw_html)
for details.
"""
from __future__ import annotations
from . import Extension
from ..blockprocessors import BlockProcessor
from ..preprocessors import Preprocessor
from ..postprocessors import RawHtmlPostprocessor
from .. import util
from ..htmlparser import HTMLExtractor, blank_line_re
import xml.etree.ElementTree as etree
class HTMLExtractorExtra(HTMLExtractor):
"""
Override `HTMLExtractor` and create `etree` `Elements` for any elements which should have content parsed as
Markdown.
"""
def __init__(self, md, *args, **kwargs):
# All block-level tags.
self.block_level_tags = set(md.block_level_elements.copy())
# Block-level tags in which the content only gets span level parsing
self.span_tags = set(
['address', 'dd', 'dt', 'h1', 'h2', 'h3', 'h4', 'h5', 'h6', 'legend', 'li', 'p', 'summary', 'td', 'th']
)
# Block-level tags which never get their content parsed.
self.raw_tags = set(['canvas', 'math', 'option', 'pre', 'script', 'style', 'textarea'])
super().__init__(md, *args, **kwargs)
# Block-level tags in which the content gets parsed as blocks
self.block_tags = set(self.block_level_tags) - (self.span_tags | self.raw_tags | self.empty_tags)
self.span_and_blocks_tags = self.block_tags | self.span_tags
def reset(self):
"""Reset this instance. Loses all unprocessed data."""
self.mdstack = [] # When markdown=1, stack contains a list of tags
self.treebuilder = etree.TreeBuilder()
self.mdstate = [] # one of 'block', 'span', 'off', or None
super().reset()
def close(self):
"""Handle any buffered data."""
super().close()
# Handle any unclosed tags.
if self.mdstack:
# Close the outermost parent. `handle_endtag` will close all unclosed children.
self.handle_endtag(self.mdstack[0])
def get_element(self):
""" Return element from `treebuilder` and reset `treebuilder` for later use. """
element = self.treebuilder.close()
self.treebuilder = etree.TreeBuilder()
return element
def get_state(self, tag, attrs):
""" Return state from tag and `markdown` attribute. One of 'block', 'span', or 'off'. """
md_attr = attrs.get('markdown', '0')
if md_attr == 'markdown':
# `<tag markdown>` is the same as `<tag markdown='1'>`.
md_attr = '1'
parent_state = self.mdstate[-1] if self.mdstate else None
if parent_state == 'off' or (parent_state == 'span' and md_attr != '0'):
# Only use the parent state if it is more restrictive than the markdown attribute.
md_attr = parent_state
if ((md_attr == '1' and tag in self.block_tags) or
(md_attr == 'block' and tag in self.span_and_blocks_tags)):
return 'block'
elif ((md_attr == '1' and tag in self.span_tags) or
(md_attr == 'span' and tag in self.span_and_blocks_tags)):
return 'span'
elif tag in self.block_level_tags:
return 'off'
else: # pragma: no cover
return None
def handle_starttag(self, tag, attrs):
# Handle tags that should always be empty and do not specify a closing tag
if tag in self.empty_tags and (self.at_line_start() or self.intail):
attrs = {key: value if value is not None else key for key, value in attrs}
if "markdown" in attrs:
attrs.pop('markdown')
element = etree.Element(tag, attrs)
data = etree.tostring(element, encoding='unicode', method='html')
else:
data = self.get_starttag_text()
self.handle_empty_tag(data, True)
return
if tag in self.block_level_tags and (self.at_line_start() or self.intail):
# Valueless attribute (ex: `<tag checked>`) results in `[('checked', None)]`.
# Convert to `{'checked': 'checked'}`.
attrs = {key: value if value is not None else key for key, value in attrs}
state = self.get_state(tag, attrs)
if self.inraw or (state in [None, 'off'] and not self.mdstack):
# fall back to default behavior
attrs.pop('markdown', None)
super().handle_starttag(tag, attrs)
else:
if 'p' in self.mdstack and tag in self.block_level_tags:
# Close unclosed 'p' tag
self.handle_endtag('p')
self.mdstate.append(state)
self.mdstack.append(tag)
attrs['markdown'] = state
self.treebuilder.start(tag, attrs)
else:
# Span level tag
if self.inraw:
super().handle_starttag(tag, attrs)
else:
text = self.get_starttag_text()
if self.mdstate and self.mdstate[-1] == "off":
self.handle_data(self.md.htmlStash.store(text))
else:
self.handle_data(text)
if tag in self.CDATA_CONTENT_ELEMENTS:
# This is presumably a standalone tag in a code span (see #1036).
self.clear_cdata_mode()
def handle_endtag(self, tag):
if tag in self.block_level_tags:
if self.inraw:
super().handle_endtag(tag)
elif tag in self.mdstack:
# Close element and any unclosed children
while self.mdstack:
item = self.mdstack.pop()
self.mdstate.pop()
self.treebuilder.end(item)
if item == tag:
break
if not self.mdstack:
# Last item in stack is closed. Stash it
element = self.get_element()
# Get last entry to see if it ends in newlines
# If it is an element, assume there is no newlines
item = self.cleandoc[-1] if self.cleandoc else ''
# If we only have one newline before block element, add another
if not item.endswith('\n\n') and item.endswith('\n'):
self.cleandoc.append('\n')
self.cleandoc.append(self.md.htmlStash.store(element))
self.cleandoc.append('\n\n')
self.state = []
# Check if element has a tail
if not blank_line_re.match(
self.rawdata[self.line_offset + self.offset + len(self.get_endtag_text(tag)):]):
# More content exists after `endtag`.
self.intail = True
else:
# Treat orphan closing tag as a span level tag.
text = self.get_endtag_text(tag)
if self.mdstate and self.mdstate[-1] == "off":
self.handle_data(self.md.htmlStash.store(text))
else:
self.handle_data(text)
else:
# Span level tag
if self.inraw:
super().handle_endtag(tag)
else:
text = self.get_endtag_text(tag)
if self.mdstate and self.mdstate[-1] == "off":
self.handle_data(self.md.htmlStash.store(text))
else:
self.handle_data(text)
def handle_startendtag(self, tag, attrs):
if tag in self.empty_tags:
attrs = {key: value if value is not None else key for key, value in attrs}
if "markdown" in attrs:
attrs.pop('markdown')
element = etree.Element(tag, attrs)
data = etree.tostring(element, encoding='unicode', method='html')
else:
data = self.get_starttag_text()
else:
data = self.get_starttag_text()
self.handle_empty_tag(data, is_block=self.md.is_block_level(tag))
def handle_data(self, data):
if self.intail and '\n' in data:
self.intail = False
if self.inraw or not self.mdstack:
super().handle_data(data)
else:
self.treebuilder.data(data)
def handle_empty_tag(self, data, is_block):
if self.inraw or not self.mdstack:
super().handle_empty_tag(data, is_block)
else:
if self.at_line_start() and is_block:
self.handle_data('\n' + self.md.htmlStash.store(data) + '\n\n')
else:
self.handle_data(self.md.htmlStash.store(data))
def parse_pi(self, i):
if self.at_line_start() or self.intail or self.mdstack:
# The same override exists in `HTMLExtractor` without the check
# for `mdstack`. Therefore, use parent of `HTMLExtractor` instead.
return super(HTMLExtractor, self).parse_pi(i)
# This is not the beginning of a raw block so treat as plain data
# and avoid consuming any tags which may follow (see #1066).
self.handle_data('<?')
return i + 2
def parse_html_declaration(self, i):
if self.at_line_start() or self.intail or self.mdstack:
# The same override exists in `HTMLExtractor` without the check
# for `mdstack`. Therefore, use parent of `HTMLExtractor` instead.
return super(HTMLExtractor, self).parse_html_declaration(i)
# This is not the beginning of a raw block so treat as plain data
# and avoid consuming any tags which may follow (see #1066).
self.handle_data('<!')
return i + 2
class HtmlBlockPreprocessor(Preprocessor):
"""Remove html blocks from the text and store them for later retrieval."""
def run(self, lines):
source = '\n'.join(lines)
parser = HTMLExtractorExtra(self.md)
parser.feed(source)
parser.close()
return ''.join(parser.cleandoc).split('\n')
class MarkdownInHtmlProcessor(BlockProcessor):
"""Process Markdown Inside HTML Blocks which have been stored in the `HtmlStash`."""
def test(self, parent, block):
# Always return True. `run` will return `False` it not a valid match.
return True
def parse_element_content(self, element):
"""
Recursively parse the text content of an `etree` Element as Markdown.
Any block level elements generated from the Markdown will be inserted as children of the element in place
of the text content. All `markdown` attributes are removed. For any elements in which Markdown parsing has
been disabled, the text content of it and its children are wrapped in an `AtomicString`.
"""
md_attr = element.attrib.pop('markdown', 'off')
if md_attr == 'block':
# Parse content as block level
# The order in which the different parts are parsed (text, children, tails) is important here as the
# order of elements needs to be preserved. We can't be inserting items at a later point in the current
# iteration as we don't want to do raw processing on elements created from parsing Markdown text (for
# example). Therefore, the order of operations is children, tails, text.
# Recursively parse existing children from raw HTML
for child in list(element):
self.parse_element_content(child)
# Parse Markdown text in tail of children. Do this separate to avoid raw HTML parsing.
# Save the position of each item to be inserted later in reverse.
tails = []
for pos, child in enumerate(element):
if child.tail:
block = child.tail.rstrip('\n')
child.tail = ''
# Use a dummy placeholder element.
dummy = etree.Element('div')
self.parser.parseBlocks(dummy, block.split('\n\n'))
children = list(dummy)
children.reverse()
tails.append((pos + 1, children))
# Insert the elements created from the tails in reverse.
tails.reverse()
for pos, tail in tails:
for item in tail:
element.insert(pos, item)
# Parse Markdown text content. Do this last to avoid raw HTML parsing.
if element.text:
block = element.text.rstrip('\n')
element.text = ''
# Use a dummy placeholder element as the content needs to get inserted before existing children.
dummy = etree.Element('div')
self.parser.parseBlocks(dummy, block.split('\n\n'))
children = list(dummy)
children.reverse()
for child in children:
element.insert(0, child)
elif md_attr == 'span':
# Span level parsing will be handled by inline processors.
# Walk children here to remove any `markdown` attributes.
for child in list(element):
self.parse_element_content(child)
else:
# Disable inline parsing for everything else
if element.text is None:
element.text = ''
element.text = util.AtomicString(element.text)
for child in list(element):
self.parse_element_content(child)
if child.tail:
child.tail = util.AtomicString(child.tail)
def run(self, parent, blocks):
m = util.HTML_PLACEHOLDER_RE.match(blocks[0])
if m:
index = int(m.group(1))
element = self.parser.md.htmlStash.rawHtmlBlocks[index]
if isinstance(element, etree.Element):
# We have a matched element. Process it.
blocks.pop(0)
self.parse_element_content(element)
parent.append(element)
# Cleanup stash. Replace element with empty string to avoid confusing postprocessor.
self.parser.md.htmlStash.rawHtmlBlocks.pop(index)
self.parser.md.htmlStash.rawHtmlBlocks.insert(index, '')
# Confirm the match to the `blockparser`.
return True
# No match found.
return False
class MarkdownInHTMLPostprocessor(RawHtmlPostprocessor):
def stash_to_string(self, text):
""" Override default to handle any `etree` elements still in the stash. """
if isinstance(text, etree.Element):
return self.md.serializer(text)
else:
return str(text)
class MarkdownInHtmlExtension(Extension):
"""Add Markdown parsing in HTML to Markdown class."""
def extendMarkdown(self, md):
""" Register extension instances. """
# Replace raw HTML preprocessor
md.preprocessors.register(HtmlBlockPreprocessor(md), 'html_block', 20)
# Add `blockprocessor` which handles the placeholders for `etree` elements
md.parser.blockprocessors.register(
MarkdownInHtmlProcessor(md.parser), 'markdown_block', 105
)
# Replace raw HTML postprocessor
md.postprocessors.register(MarkdownInHTMLPostprocessor(md), 'raw_html', 30)
def makeExtension(**kwargs): # pragma: no cover
return MarkdownInHtmlExtension(**kwargs)

View File

@@ -0,0 +1,85 @@
# Meta Data Extension for Python-Markdown
# =======================================
# This extension adds Meta Data handling to markdown.
# See https://Python-Markdown.github.io/extensions/meta_data
# for documentation.
# Original code Copyright 2007-2008 [Waylan Limberg](http://achinghead.com).
# All changes Copyright 2008-2014 The Python Markdown Project
# License: [BSD](https://opensource.org/licenses/bsd-license.php)
"""
This extension adds Meta Data handling to markdown.
See the [documentation](https://Python-Markdown.github.io/extensions/meta_data)
for details.
"""
from __future__ import annotations
from . import Extension
from ..preprocessors import Preprocessor
import re
import logging
log = logging.getLogger('MARKDOWN')
# Global Vars
META_RE = re.compile(r'^[ ]{0,3}(?P<key>[A-Za-z0-9_-]+):\s*(?P<value>.*)')
META_MORE_RE = re.compile(r'^[ ]{4,}(?P<value>.*)')
BEGIN_RE = re.compile(r'^-{3}(\s.*)?')
END_RE = re.compile(r'^(-{3}|\.{3})(\s.*)?')
class MetaExtension (Extension):
""" Meta-Data extension for Python-Markdown. """
def extendMarkdown(self, md):
""" Add `MetaPreprocessor` to Markdown instance. """
md.registerExtension(self)
self.md = md
md.preprocessors.register(MetaPreprocessor(md), 'meta', 27)
def reset(self) -> None:
self.md.Meta = {}
class MetaPreprocessor(Preprocessor):
""" Get Meta-Data. """
def run(self, lines):
""" Parse Meta-Data and store in Markdown.Meta. """
meta = {}
key = None
if lines and BEGIN_RE.match(lines[0]):
lines.pop(0)
while lines:
line = lines.pop(0)
m1 = META_RE.match(line)
if line.strip() == '' or END_RE.match(line):
break # blank line or end of YAML header - done
if m1:
key = m1.group('key').lower().strip()
value = m1.group('value').strip()
try:
meta[key].append(value)
except KeyError:
meta[key] = [value]
else:
m2 = META_MORE_RE.match(line)
if m2 and key:
# Add another line to existing key
meta[key].append(m2.group('value').strip())
else:
lines.insert(0, line)
break # no meta data - done
self.md.Meta = meta
return lines
def makeExtension(**kwargs): # pragma: no cover
return MetaExtension(**kwargs)

View File

@@ -0,0 +1,41 @@
# `NL2BR` Extension
# ===============
# A Python-Markdown extension to treat newlines as hard breaks; like
# GitHub-flavored Markdown does.
# See https://Python-Markdown.github.io/extensions/nl2br
# for documentation.
# Original code Copyright 2011 [Brian Neal](https://deathofagremmie.com/)
# All changes Copyright 2011-2014 The Python Markdown Project
# License: [BSD](https://opensource.org/licenses/bsd-license.php)
"""
A Python-Markdown extension to treat newlines as hard breaks; like
GitHub-flavored Markdown does.
See the [documentation](https://Python-Markdown.github.io/extensions/nl2br)
for details.
"""
from __future__ import annotations
from . import Extension
from ..inlinepatterns import SubstituteTagInlineProcessor
BR_RE = r'\n'
class Nl2BrExtension(Extension):
def extendMarkdown(self, md):
""" Add a `SubstituteTagInlineProcessor` to Markdown. """
br_tag = SubstituteTagInlineProcessor(BR_RE, 'br')
md.inlinePatterns.register(br_tag, 'nl', 5)
def makeExtension(**kwargs): # pragma: no cover
return Nl2BrExtension(**kwargs)

View File

@@ -0,0 +1,65 @@
# Sane List Extension for Python-Markdown
# =======================================
# Modify the behavior of Lists in Python-Markdown to act in a sane manor.
# See https://Python-Markdown.github.io/extensions/sane_lists
# for documentation.
# Original code Copyright 2011 [Waylan Limberg](http://achinghead.com)
# All changes Copyright 2011-2014 The Python Markdown Project
# License: [BSD](https://opensource.org/licenses/bsd-license.php)
"""
Modify the behavior of Lists in Python-Markdown to act in a sane manor.
See [documentation](https://Python-Markdown.github.io/extensions/sane_lists)
for details.
"""
from __future__ import annotations
from . import Extension
from ..blockprocessors import OListProcessor, UListProcessor
import re
class SaneOListProcessor(OListProcessor):
""" Override `SIBLING_TAGS` to not include `ul` and set `LAZY_OL` to `False`. """
SIBLING_TAGS = ['ol']
""" Exclude `ul` from list of siblings. """
LAZY_OL = False
""" Disable lazy list behavior. """
def __init__(self, parser):
super().__init__(parser)
self.CHILD_RE = re.compile(r'^[ ]{0,%d}((\d+\.))[ ]+(.*)' %
(self.tab_length - 1))
class SaneUListProcessor(UListProcessor):
""" Override `SIBLING_TAGS` to not include `ol`. """
SIBLING_TAGS = ['ul']
""" Exclude `ol` from list of siblings. """
def __init__(self, parser):
super().__init__(parser)
self.CHILD_RE = re.compile(r'^[ ]{0,%d}(([*+-]))[ ]+(.*)' %
(self.tab_length - 1))
class SaneListExtension(Extension):
""" Add sane lists to Markdown. """
def extendMarkdown(self, md):
""" Override existing Processors. """
md.parser.blockprocessors.register(SaneOListProcessor(md.parser), 'olist', 40)
md.parser.blockprocessors.register(SaneUListProcessor(md.parser), 'ulist', 30)
def makeExtension(**kwargs): # pragma: no cover
return SaneListExtension(**kwargs)

View File

@@ -0,0 +1,265 @@
# Smarty extension for Python-Markdown
# ====================================
# Adds conversion of ASCII dashes, quotes and ellipses to their HTML
# entity equivalents.
# See https://Python-Markdown.github.io/extensions/smarty
# for documentation.
# Author: 2013, Dmitry Shachnev <mitya57@gmail.com>
# All changes Copyright 2013-2014 The Python Markdown Project
# License: [BSD](https://opensource.org/licenses/bsd-license.php)
# SmartyPants license:
# Copyright (c) 2003 John Gruber <https://daringfireball.net/>
# All rights reserved.
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in
# the documentation and/or other materials provided with the
# distribution.
# * Neither the name "SmartyPants" nor the names of its contributors
# may be used to endorse or promote products derived from this
# software without specific prior written permission.
# This software is provided by the copyright holders and contributors "as
# is" and any express or implied warranties, including, but not limited
# to, the implied warranties of merchantability and fitness for a
# particular purpose are disclaimed. In no event shall the copyright
# owner or contributors be liable for any direct, indirect, incidental,
# special, exemplary, or consequential damages (including, but not
# limited to, procurement of substitute goods or services; loss of use,
# data, or profits; or business interruption) however caused and on any
# theory of liability, whether in contract, strict liability, or tort
# (including negligence or otherwise) arising in any way out of the use
# of this software, even if advised of the possibility of such damage.
# `smartypants.py` license:
# `smartypants.py` is a derivative work of SmartyPants.
# Copyright (c) 2004, 2007 Chad Miller <http://web.chad.org/>
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in
# the documentation and/or other materials provided with the
# distribution.
# This software is provided by the copyright holders and contributors "as
# is" and any express or implied warranties, including, but not limited
# to, the implied warranties of merchantability and fitness for a
# particular purpose are disclaimed. In no event shall the copyright
# owner or contributors be liable for any direct, indirect, incidental,
# special, exemplary, or consequential damages (including, but not
# limited to, procurement of substitute goods or services; loss of use,
# data, or profits; or business interruption) however caused and on any
# theory of liability, whether in contract, strict liability, or tort
# (including negligence or otherwise) arising in any way out of the use
# of this software, even if advised of the possibility of such damage.
"""
Adds conversion of ASCII dashes, quotes and ellipses to their HTML
entity equivalents.
See the [documentation](https://Python-Markdown.github.io/extensions/smarty)
for details.
"""
from __future__ import annotations
from . import Extension
from ..inlinepatterns import HtmlInlineProcessor, HTML_RE
from ..treeprocessors import InlineProcessor
from ..util import Registry
# Constants for quote education.
punctClass = r"""[!"#\$\%'()*+,-.\/:;<=>?\@\[\\\]\^_`{|}~]"""
endOfWordClass = r"[\s.,;:!?)]"
closeClass = r"[^\ \t\r\n\[\{\(\-\u0002\u0003]"
openingQuotesBase = (
r'(\s' # a whitespace char
r'|&nbsp;' # or a non-breaking space entity
r'|--' # or dashes
r'||—' # or Unicode
r'|&[mn]dash;' # or named dash entities
r'|&#8211;|&#8212;' # or decimal entities
r')'
)
substitutions = {
'mdash': '&mdash;',
'ndash': '&ndash;',
'ellipsis': '&hellip;',
'left-angle-quote': '&laquo;',
'right-angle-quote': '&raquo;',
'left-single-quote': '&lsquo;',
'right-single-quote': '&rsquo;',
'left-double-quote': '&ldquo;',
'right-double-quote': '&rdquo;',
}
# Special case if the very first character is a quote
# followed by punctuation at a non-word-break. Close the quotes by brute force:
singleQuoteStartRe = r"^'(?=%s\B)" % punctClass
doubleQuoteStartRe = r'^"(?=%s\B)' % punctClass
# Special case for double sets of quotes, e.g.:
# <p>He said, "'Quoted' words in a larger quote."</p>
doubleQuoteSetsRe = r""""'(?=\w)"""
singleQuoteSetsRe = r"""'"(?=\w)"""
# Special case for decade abbreviations (the '80s):
decadeAbbrRe = r"(?<!\w)'(?=\d{2}s)"
# Get most opening double quotes:
openingDoubleQuotesRegex = r'%s"(?=\w)' % openingQuotesBase
# Double closing quotes:
closingDoubleQuotesRegex = r'"(?=\s)'
closingDoubleQuotesRegex2 = '(?<=%s)"' % closeClass
# Get most opening single quotes:
openingSingleQuotesRegex = r"%s'(?=\w)" % openingQuotesBase
# Single closing quotes:
closingSingleQuotesRegex = r"(?<=%s)'(?!\s|s\b|\d)" % closeClass
closingSingleQuotesRegex2 = r"'(\s|s\b)"
# All remaining quotes should be opening ones
remainingSingleQuotesRegex = r"'"
remainingDoubleQuotesRegex = r'"'
HTML_STRICT_RE = HTML_RE + r'(?!\>)'
class SubstituteTextPattern(HtmlInlineProcessor):
def __init__(self, pattern, replace, md):
""" Replaces matches with some text. """
HtmlInlineProcessor.__init__(self, pattern)
self.replace = replace
self.md = md
def handleMatch(self, m, data):
result = ''
for part in self.replace:
if isinstance(part, int):
result += m.group(part)
else:
result += self.md.htmlStash.store(part)
return result, m.start(0), m.end(0)
class SmartyExtension(Extension):
""" Add Smarty to Markdown. """
def __init__(self, **kwargs):
self.config = {
'smart_quotes': [True, 'Educate quotes'],
'smart_angled_quotes': [False, 'Educate angled quotes'],
'smart_dashes': [True, 'Educate dashes'],
'smart_ellipses': [True, 'Educate ellipses'],
'substitutions': [{}, 'Overwrite default substitutions'],
}
""" Default configuration options. """
super().__init__(**kwargs)
self.substitutions = dict(substitutions)
self.substitutions.update(self.getConfig('substitutions', default={}))
def _addPatterns(self, md, patterns, serie, priority):
for ind, pattern in enumerate(patterns):
pattern += (md,)
pattern = SubstituteTextPattern(*pattern)
name = 'smarty-%s-%d' % (serie, ind)
self.inlinePatterns.register(pattern, name, priority-ind)
def educateDashes(self, md) -> None:
emDashesPattern = SubstituteTextPattern(
r'(?<!-)---(?!-)', (self.substitutions['mdash'],), md
)
enDashesPattern = SubstituteTextPattern(
r'(?<!-)--(?!-)', (self.substitutions['ndash'],), md
)
self.inlinePatterns.register(emDashesPattern, 'smarty-em-dashes', 50)
self.inlinePatterns.register(enDashesPattern, 'smarty-en-dashes', 45)
def educateEllipses(self, md) -> None:
ellipsesPattern = SubstituteTextPattern(
r'(?<!\.)\.{3}(?!\.)', (self.substitutions['ellipsis'],), md
)
self.inlinePatterns.register(ellipsesPattern, 'smarty-ellipses', 10)
def educateAngledQuotes(self, md) -> None:
leftAngledQuotePattern = SubstituteTextPattern(
r'\<\<', (self.substitutions['left-angle-quote'],), md
)
rightAngledQuotePattern = SubstituteTextPattern(
r'\>\>', (self.substitutions['right-angle-quote'],), md
)
self.inlinePatterns.register(leftAngledQuotePattern, 'smarty-left-angle-quotes', 40)
self.inlinePatterns.register(rightAngledQuotePattern, 'smarty-right-angle-quotes', 35)
def educateQuotes(self, md) -> None:
lsquo = self.substitutions['left-single-quote']
rsquo = self.substitutions['right-single-quote']
ldquo = self.substitutions['left-double-quote']
rdquo = self.substitutions['right-double-quote']
patterns = (
(singleQuoteStartRe, (rsquo,)),
(doubleQuoteStartRe, (rdquo,)),
(doubleQuoteSetsRe, (ldquo + lsquo,)),
(singleQuoteSetsRe, (lsquo + ldquo,)),
(decadeAbbrRe, (rsquo,)),
(openingSingleQuotesRegex, (1, lsquo)),
(closingSingleQuotesRegex, (rsquo,)),
(closingSingleQuotesRegex2, (rsquo, 1)),
(remainingSingleQuotesRegex, (lsquo,)),
(openingDoubleQuotesRegex, (1, ldquo)),
(closingDoubleQuotesRegex, (rdquo,)),
(closingDoubleQuotesRegex2, (rdquo,)),
(remainingDoubleQuotesRegex, (ldquo,))
)
self._addPatterns(md, patterns, 'quotes', 30)
def extendMarkdown(self, md):
configs = self.getConfigs()
self.inlinePatterns: Registry[HtmlInlineProcessor] = Registry()
if configs['smart_ellipses']:
self.educateEllipses(md)
if configs['smart_quotes']:
self.educateQuotes(md)
if configs['smart_angled_quotes']:
self.educateAngledQuotes(md)
# Override `HTML_RE` from `inlinepatterns.py` so that it does not
# process tags with duplicate closing quotes.
md.inlinePatterns.register(HtmlInlineProcessor(HTML_STRICT_RE, md), 'html', 90)
if configs['smart_dashes']:
self.educateDashes(md)
inlineProcessor = InlineProcessor(md)
inlineProcessor.inlinePatterns = self.inlinePatterns
md.treeprocessors.register(inlineProcessor, 'smarty', 2)
md.ESCAPED_CHARS.extend(['"', "'"])
def makeExtension(**kwargs): # pragma: no cover
return SmartyExtension(**kwargs)

View File

@@ -0,0 +1,243 @@
# Tables Extension for Python-Markdown
# ====================================
# Added parsing of tables to Python-Markdown.
# See https://Python-Markdown.github.io/extensions/tables
# for documentation.
# Original code Copyright 2009 [Waylan Limberg](http://achinghead.com)
# All changes Copyright 2008-2014 The Python Markdown Project
# License: [BSD](https://opensource.org/licenses/bsd-license.php)
"""
Added parsing of tables to Python-Markdown.
See the [documentation](https://Python-Markdown.github.io/extensions/tables)
for details.
"""
from __future__ import annotations
from . import Extension
from ..blockprocessors import BlockProcessor
import xml.etree.ElementTree as etree
import re
PIPE_NONE = 0
PIPE_LEFT = 1
PIPE_RIGHT = 2
class TableProcessor(BlockProcessor):
""" Process Tables. """
RE_CODE_PIPES = re.compile(r'(?:(\\\\)|(\\`+)|(`+)|(\\\|)|(\|))')
RE_END_BORDER = re.compile(r'(?<!\\)(?:\\\\)*\|$')
def __init__(self, parser, config):
self.border = False
self.separator = ''
self.config = config
super().__init__(parser)
def test(self, parent, block):
"""
Ensure first two rows (column header and separator row) are valid table rows.
Keep border check and separator row do avoid repeating the work.
"""
is_table = False
rows = [row.strip(' ') for row in block.split('\n')]
if len(rows) > 1:
header0 = rows[0]
self.border = PIPE_NONE
if header0.startswith('|'):
self.border |= PIPE_LEFT
if self.RE_END_BORDER.search(header0) is not None:
self.border |= PIPE_RIGHT
row = self._split_row(header0)
row0_len = len(row)
is_table = row0_len > 1
# Each row in a single column table needs at least one pipe.
if not is_table and row0_len == 1 and self.border:
for index in range(1, len(rows)):
is_table = rows[index].startswith('|')
if not is_table:
is_table = self.RE_END_BORDER.search(rows[index]) is not None
if not is_table:
break
if is_table:
row = self._split_row(rows[1])
is_table = (len(row) == row0_len) and set(''.join(row)) <= set('|:- ')
if is_table:
self.separator = row
return is_table
def run(self, parent, blocks):
""" Parse a table block and build table. """
block = blocks.pop(0).split('\n')
header = block[0].strip(' ')
rows = [] if len(block) < 3 else block[2:]
# Get alignment of columns
align = []
for c in self.separator:
c = c.strip(' ')
if c.startswith(':') and c.endswith(':'):
align.append('center')
elif c.startswith(':'):
align.append('left')
elif c.endswith(':'):
align.append('right')
else:
align.append(None)
# Build table
table = etree.SubElement(parent, 'table')
thead = etree.SubElement(table, 'thead')
self._build_row(header, thead, align)
tbody = etree.SubElement(table, 'tbody')
if len(rows) == 0:
# Handle empty table
self._build_empty_row(tbody, align)
else:
for row in rows:
self._build_row(row.strip(' '), tbody, align)
def _build_empty_row(self, parent, align):
"""Build an empty row."""
tr = etree.SubElement(parent, 'tr')
count = len(align)
while count:
etree.SubElement(tr, 'td')
count -= 1
def _build_row(self, row, parent, align):
""" Given a row of text, build table cells. """
tr = etree.SubElement(parent, 'tr')
tag = 'td'
if parent.tag == 'thead':
tag = 'th'
cells = self._split_row(row)
# We use align here rather than cells to ensure every row
# contains the same number of columns.
for i, a in enumerate(align):
c = etree.SubElement(tr, tag)
try:
c.text = cells[i].strip(' ')
except IndexError: # pragma: no cover
c.text = ""
if a:
if self.config['use_align_attribute']:
c.set('align', a)
else:
c.set('style', f'text-align: {a};')
def _split_row(self, row):
""" split a row of text into list of cells. """
if self.border:
if row.startswith('|'):
row = row[1:]
row = self.RE_END_BORDER.sub('', row)
return self._split(row)
def _split(self, row):
""" split a row of text with some code into a list of cells. """
elements = []
pipes = []
tics = []
tic_points = []
tic_region = []
good_pipes = []
# Parse row
# Throw out \\, and \|
for m in self.RE_CODE_PIPES.finditer(row):
# Store ` data (len, start_pos, end_pos)
if m.group(2):
# \`+
# Store length of each tic group: subtract \
tics.append(len(m.group(2)) - 1)
# Store start of group, end of group, and escape length
tic_points.append((m.start(2), m.end(2) - 1, 1))
elif m.group(3):
# `+
# Store length of each tic group
tics.append(len(m.group(3)))
# Store start of group, end of group, and escape length
tic_points.append((m.start(3), m.end(3) - 1, 0))
# Store pipe location
elif m.group(5):
pipes.append(m.start(5))
# Pair up tics according to size if possible
# Subtract the escape length *only* from the opening.
# Walk through tic list and see if tic has a close.
# Store the tic region (start of region, end of region).
pos = 0
tic_len = len(tics)
while pos < tic_len:
try:
tic_size = tics[pos] - tic_points[pos][2]
if tic_size == 0:
raise ValueError
index = tics[pos + 1:].index(tic_size) + 1
tic_region.append((tic_points[pos][0], tic_points[pos + index][1]))
pos += index + 1
except ValueError:
pos += 1
# Resolve pipes. Check if they are within a tic pair region.
# Walk through pipes comparing them to each region.
# - If pipe position is less that a region, it isn't in a region
# - If it is within a region, we don't want it, so throw it out
# - If we didn't throw it out, it must be a table pipe
for pipe in pipes:
throw_out = False
for region in tic_region:
if pipe < region[0]:
# Pipe is not in a region
break
elif region[0] <= pipe <= region[1]:
# Pipe is within a code region. Throw it out.
throw_out = True
break
if not throw_out:
good_pipes.append(pipe)
# Split row according to table delimiters.
pos = 0
for pipe in good_pipes:
elements.append(row[pos:pipe])
pos = pipe + 1
elements.append(row[pos:])
return elements
class TableExtension(Extension):
""" Add tables to Markdown. """
def __init__(self, **kwargs):
self.config = {
'use_align_attribute': [False, 'True to use align attribute instead of style.'],
}
""" Default configuration options. """
super().__init__(**kwargs)
def extendMarkdown(self, md):
""" Add an instance of `TableProcessor` to `BlockParser`. """
if '|' not in md.ESCAPED_CHARS:
md.ESCAPED_CHARS.append('|')
processor = TableProcessor(md.parser, self.getConfigs())
md.parser.blockprocessors.register(processor, 'table', 75)
def makeExtension(**kwargs): # pragma: no cover
return TableExtension(**kwargs)

View File

@@ -0,0 +1,408 @@
# Table of Contents Extension for Python-Markdown
# ===============================================
# See https://Python-Markdown.github.io/extensions/toc
# for documentation.
# Original code Copyright 2008 [Jack Miller](https://codezen.org/)
# All changes Copyright 2008-2014 The Python Markdown Project
# License: [BSD](https://opensource.org/licenses/bsd-license.php)
"""
Add table of contents support to Python-Markdown.
See the [documentation](https://Python-Markdown.github.io/extensions/toc)
for details.
"""
from __future__ import annotations
from . import Extension
from ..treeprocessors import Treeprocessor
from ..util import code_escape, parseBoolValue, AMP_SUBSTITUTE, HTML_PLACEHOLDER_RE, AtomicString
from ..treeprocessors import UnescapeTreeprocessor
import re
import html
import unicodedata
import xml.etree.ElementTree as etree
def slugify(value, separator, unicode=False):
""" Slugify a string, to make it URL friendly. """
if not unicode:
# Replace Extended Latin characters with ASCII, i.e. `žlutý` => `zluty`
value = unicodedata.normalize('NFKD', value)
value = value.encode('ascii', 'ignore').decode('ascii')
value = re.sub(r'[^\w\s-]', '', value).strip().lower()
return re.sub(r'[{}\s]+'.format(separator), separator, value)
def slugify_unicode(value, separator):
""" Slugify a string, to make it URL friendly while preserving Unicode characters. """
return slugify(value, separator, unicode=True)
IDCOUNT_RE = re.compile(r'^(.*)_([0-9]+)$')
def unique(id, ids):
""" Ensure id is unique in set of ids. Append '_1', '_2'... if not """
while id in ids or not id:
m = IDCOUNT_RE.match(id)
if m:
id = '%s_%d' % (m.group(1), int(m.group(2))+1)
else:
id = '%s_%d' % (id, 1)
ids.add(id)
return id
def get_name(el):
"""Get title name."""
text = []
for c in el.itertext():
if isinstance(c, AtomicString):
text.append(html.unescape(c))
else:
text.append(c)
return ''.join(text).strip()
def stashedHTML2text(text, md, strip_entities: bool = True):
""" Extract raw HTML from stash, reduce to plain text and swap with placeholder. """
def _html_sub(m):
""" Substitute raw html with plain text. """
try:
raw = md.htmlStash.rawHtmlBlocks[int(m.group(1))]
except (IndexError, TypeError): # pragma: no cover
return m.group(0)
# Strip out tags and/or entities - leaving text
res = re.sub(r'(<[^>]+>)', '', raw)
if strip_entities:
res = re.sub(r'(&[\#a-zA-Z0-9]+;)', '', res)
return res
return HTML_PLACEHOLDER_RE.sub(_html_sub, text)
def unescape(text):
""" Unescape escaped text. """
c = UnescapeTreeprocessor()
return c.unescape(text)
def nest_toc_tokens(toc_list):
"""Given an unsorted list with errors and skips, return a nested one.
[{'level': 1}, {'level': 2}]
=>
[{'level': 1, 'children': [{'level': 2, 'children': []}]}]
A wrong list is also converted:
[{'level': 2}, {'level': 1}]
=>
[{'level': 2, 'children': []}, {'level': 1, 'children': []}]
"""
ordered_list = []
if len(toc_list):
# Initialize everything by processing the first entry
last = toc_list.pop(0)
last['children'] = []
levels = [last['level']]
ordered_list.append(last)
parents = []
# Walk the rest nesting the entries properly
while toc_list:
t = toc_list.pop(0)
current_level = t['level']
t['children'] = []
# Reduce depth if current level < last item's level
if current_level < levels[-1]:
# Pop last level since we know we are less than it
levels.pop()
# Pop parents and levels we are less than or equal to
to_pop = 0
for p in reversed(parents):
if current_level <= p['level']:
to_pop += 1
else: # pragma: no cover
break
if to_pop:
levels = levels[:-to_pop]
parents = parents[:-to_pop]
# Note current level as last
levels.append(current_level)
# Level is the same, so append to
# the current parent (if available)
if current_level == levels[-1]:
(parents[-1]['children'] if parents
else ordered_list).append(t)
# Current level is > last item's level,
# So make last item a parent and append current as child
else:
last['children'].append(t)
parents.append(last)
levels.append(current_level)
last = t
return ordered_list
class TocTreeprocessor(Treeprocessor):
""" Step through document and build TOC. """
def __init__(self, md, config):
super().__init__(md)
self.marker = config["marker"]
self.title = config["title"]
self.base_level = int(config["baselevel"]) - 1
self.slugify = config["slugify"]
self.sep = config["separator"]
self.toc_class = config["toc_class"]
self.title_class = config["title_class"]
self.use_anchors = parseBoolValue(config["anchorlink"])
self.anchorlink_class = config["anchorlink_class"]
self.use_permalinks = parseBoolValue(config["permalink"], False)
if self.use_permalinks is None:
self.use_permalinks = config["permalink"]
self.permalink_class = config["permalink_class"]
self.permalink_title = config["permalink_title"]
self.permalink_leading = parseBoolValue(config["permalink_leading"], False)
self.header_rgx = re.compile("[Hh][123456]")
if isinstance(config["toc_depth"], str) and '-' in config["toc_depth"]:
self.toc_top, self.toc_bottom = [int(x) for x in config["toc_depth"].split('-')]
else:
self.toc_top = 1
self.toc_bottom = int(config["toc_depth"])
def iterparent(self, node):
""" Iterator wrapper to get allowed parent and child all at once. """
# We do not allow the marker inside a header as that
# would causes an endless loop of placing a new TOC
# inside previously generated TOC.
for child in node:
if not self.header_rgx.match(child.tag) and child.tag not in ['pre', 'code']:
yield node, child
yield from self.iterparent(child)
def replace_marker(self, root, elem) -> None:
""" Replace marker with elem. """
for (p, c) in self.iterparent(root):
text = ''.join(c.itertext()).strip()
if not text:
continue
# To keep the output from screwing up the
# validation by putting a `<div>` inside of a `<p>`
# we actually replace the `<p>` in its entirety.
# The `<p>` element may contain more than a single text content
# (`nl2br` can introduce a `<br>`). In this situation, `c.text` returns
# the very first content, ignore children contents or tail content.
# `len(c) == 0` is here to ensure there is only text in the `<p>`.
if c.text and c.text.strip() == self.marker and len(c) == 0:
for i in range(len(p)):
if p[i] == c:
p[i] = elem
break
def set_level(self, elem) -> None:
""" Adjust header level according to base level. """
level = int(elem.tag[-1]) + self.base_level
if level > 6:
level = 6
elem.tag = 'h%d' % level
def add_anchor(self, c, elem_id) -> None:
anchor = etree.Element("a")
anchor.text = c.text
anchor.attrib["href"] = "#" + elem_id
anchor.attrib["class"] = self.anchorlink_class
c.text = ""
for elem in c:
anchor.append(elem)
while len(c):
c.remove(c[0])
c.append(anchor)
def add_permalink(self, c, elem_id) -> None:
permalink = etree.Element("a")
permalink.text = ("%spara;" % AMP_SUBSTITUTE
if self.use_permalinks is True
else self.use_permalinks)
permalink.attrib["href"] = "#" + elem_id
permalink.attrib["class"] = self.permalink_class
if self.permalink_title:
permalink.attrib["title"] = self.permalink_title
if self.permalink_leading:
permalink.tail = c.text
c.text = ""
c.insert(0, permalink)
else:
c.append(permalink)
def build_toc_div(self, toc_list):
""" Return a string div given a toc list. """
div = etree.Element("div")
div.attrib["class"] = self.toc_class
# Add title to the div
if self.title:
header = etree.SubElement(div, "span")
if self.title_class:
header.attrib["class"] = self.title_class
header.text = self.title
def build_etree_ul(toc_list, parent):
ul = etree.SubElement(parent, "ul")
for item in toc_list:
# List item link, to be inserted into the toc div
li = etree.SubElement(ul, "li")
link = etree.SubElement(li, "a")
link.text = item.get('name', '')
link.attrib["href"] = '#' + item.get('id', '')
if item['children']:
build_etree_ul(item['children'], li)
return ul
build_etree_ul(toc_list, div)
if 'prettify' in self.md.treeprocessors:
self.md.treeprocessors['prettify'].run(div)
return div
def run(self, doc):
# Get a list of id attributes
used_ids = set()
for el in doc.iter():
if "id" in el.attrib:
used_ids.add(el.attrib["id"])
toc_tokens = []
for el in doc.iter():
if isinstance(el.tag, str) and self.header_rgx.match(el.tag):
self.set_level(el)
text = get_name(el)
# Do not override pre-existing ids
if "id" not in el.attrib:
innertext = unescape(stashedHTML2text(text, self.md))
el.attrib["id"] = unique(self.slugify(innertext, self.sep), used_ids)
if int(el.tag[-1]) >= self.toc_top and int(el.tag[-1]) <= self.toc_bottom:
toc_tokens.append({
'level': int(el.tag[-1]),
'id': el.attrib["id"],
'name': unescape(stashedHTML2text(
code_escape(el.attrib.get('data-toc-label', text)),
self.md, strip_entities=False
))
})
# Remove the data-toc-label attribute as it is no longer needed
if 'data-toc-label' in el.attrib:
del el.attrib['data-toc-label']
if self.use_anchors:
self.add_anchor(el, el.attrib["id"])
if self.use_permalinks not in [False, None]:
self.add_permalink(el, el.attrib["id"])
toc_tokens = nest_toc_tokens(toc_tokens)
div = self.build_toc_div(toc_tokens)
if self.marker:
self.replace_marker(doc, div)
# serialize and attach to markdown instance.
toc = self.md.serializer(div)
for pp in self.md.postprocessors:
toc = pp.run(toc)
self.md.toc_tokens = toc_tokens
self.md.toc = toc
class TocExtension(Extension):
TreeProcessorClass = TocTreeprocessor
def __init__(self, **kwargs):
self.config = {
'marker': [
'[TOC]',
'Text to find and replace with Table of Contents. Set to an empty string to disable. '
'Default: `[TOC]`.'
],
'title': [
'', 'Title to insert into TOC `<div>`. Default: an empty string.'
],
'title_class': [
'toctitle', 'CSS class used for the title. Default: `toctitle`.'
],
'toc_class': [
'toc', 'CSS class(es) used for the link. Default: `toclink`.'
],
'anchorlink': [
False, 'True if header should be a self link. Default: `False`.'
],
'anchorlink_class': [
'toclink', 'CSS class(es) used for the link. Defaults: `toclink`.'
],
'permalink': [
0, 'True or link text if a Sphinx-style permalink should be added. Default: `False`.'
],
'permalink_class': [
'headerlink', 'CSS class(es) used for the link. Default: `headerlink`.'
],
'permalink_title': [
'Permanent link', 'Title attribute of the permalink. Default: `Permanent link`.'
],
'permalink_leading': [
False,
'True if permalinks should be placed at start of the header, rather than end. Default: False.'
],
'baselevel': ['1', 'Base level for headers. Default: `1`.'],
'slugify': [
slugify, 'Function to generate anchors based on header text. Default: `slugify`.'
],
'separator': ['-', 'Word separator. Default: `-`.'],
'toc_depth': [
6,
'Define the range of section levels to include in the Table of Contents. A single integer '
'(b) defines the bottom section level (<h1>..<hb>) only. A string consisting of two digits '
'separated by a hyphen in between (`2-5`) defines the top (t) and the bottom (b) (<ht>..<hb>). '
'Default: `6` (bottom).'
],
}
""" Default configuration options. """
super().__init__(**kwargs)
def extendMarkdown(self, md):
""" Add TOC tree processor to Markdown. """
md.registerExtension(self)
self.md = md
self.reset()
tocext = self.TreeProcessorClass(md, self.getConfigs())
md.treeprocessors.register(tocext, 'toc', 5)
def reset(self) -> None:
self.md.toc = ''
self.md.toc_tokens = []
def makeExtension(**kwargs): # pragma: no cover
return TocExtension(**kwargs)

View File

@@ -0,0 +1,96 @@
# WikiLinks Extension for Python-Markdown
# ======================================
# Converts [[WikiLinks]] to relative links.
# See https://Python-Markdown.github.io/extensions/wikilinks
# for documentation.
# Original code Copyright [Waylan Limberg](http://achinghead.com/).
# All changes Copyright The Python Markdown Project
# License: [BSD](https://opensource.org/licenses/bsd-license.php)
"""
Converts `[[WikiLinks]]` to relative links.
See the [documentation](https://Python-Markdown.github.io/extensions/wikilinks)
for details.
"""
from __future__ import annotations
from . import Extension
from ..inlinepatterns import InlineProcessor
import xml.etree.ElementTree as etree
import re
def build_url(label, base, end):
""" Build a URL from the label, a base, and an end. """
clean_label = re.sub(r'([ ]+_)|(_[ ]+)|([ ]+)', '_', label)
return '{}{}{}'.format(base, clean_label, end)
class WikiLinkExtension(Extension):
""" Add inline processor to Markdown. """
def __init__(self, **kwargs):
self.config = {
'base_url': ['/', 'String to append to beginning or URL.'],
'end_url': ['/', 'String to append to end of URL.'],
'html_class': ['wikilink', 'CSS hook. Leave blank for none.'],
'build_url': [build_url, 'Callable formats URL from label.'],
}
""" Default configuration options. """
super().__init__(**kwargs)
def extendMarkdown(self, md):
self.md = md
# append to end of inline patterns
WIKILINK_RE = r'\[\[([\w0-9_ -]+)\]\]'
wikilinkPattern = WikiLinksInlineProcessor(WIKILINK_RE, self.getConfigs())
wikilinkPattern.md = md
md.inlinePatterns.register(wikilinkPattern, 'wikilink', 75)
class WikiLinksInlineProcessor(InlineProcessor):
""" Build link from `wikilink`. """
def __init__(self, pattern, config):
super().__init__(pattern)
self.config = config
def handleMatch(self, m, data):
if m.group(1).strip():
base_url, end_url, html_class = self._getMeta()
label = m.group(1).strip()
url = self.config['build_url'](label, base_url, end_url)
a = etree.Element('a')
a.text = label
a.set('href', url)
if html_class:
a.set('class', html_class)
else:
a = ''
return a, m.start(0), m.end(0)
def _getMeta(self):
""" Return meta data or `config` data. """
base_url = self.config['base_url']
end_url = self.config['end_url']
html_class = self.config['html_class']
if hasattr(self.md, 'Meta'):
if 'wiki_base_url' in self.md.Meta:
base_url = self.md.Meta['wiki_base_url'][0]
if 'wiki_end_url' in self.md.Meta:
end_url = self.md.Meta['wiki_end_url'][0]
if 'wiki_html_class' in self.md.Meta:
html_class = self.md.Meta['wiki_html_class'][0]
return base_url, end_url, html_class
def makeExtension(**kwargs): # pragma: no cover
return WikiLinkExtension(**kwargs)

View File

@@ -0,0 +1,334 @@
# Python Markdown
# A Python implementation of John Gruber's Markdown.
# Documentation: https://python-markdown.github.io/
# GitHub: https://github.com/Python-Markdown/markdown/
# PyPI: https://pypi.org/project/Markdown/
# Started by Manfred Stienstra (http://www.dwerg.net/).
# Maintained for a few years by Yuri Takhteyev (http://www.freewisdom.org).
# Currently maintained by Waylan Limberg (https://github.com/waylan),
# Dmitry Shachnev (https://github.com/mitya57) and Isaac Muse (https://github.com/facelessuser).
# Copyright 2007-2023 The Python Markdown Project (v. 1.7 and later)
# Copyright 2004, 2005, 2006 Yuri Takhteyev (v. 0.2-1.6b)
# Copyright 2004 Manfred Stienstra (the original version)
# License: BSD (see LICENSE.md for details).
"""
This module imports a copy of [`html.parser.HTMLParser`][] and modifies it heavily through monkey-patches.
A copy is imported rather than the module being directly imported as this ensures that the user can import
and use the unmodified library for their own needs.
"""
from __future__ import annotations
import re
import importlib.util
import sys
# Import a copy of the html.parser lib as `htmlparser` so we can monkeypatch it.
# Users can still do `from html import parser` and get the default behavior.
spec = importlib.util.find_spec('html.parser')
htmlparser = importlib.util.module_from_spec(spec)
spec.loader.exec_module(htmlparser)
sys.modules['htmlparser'] = htmlparser
# Monkeypatch `HTMLParser` to only accept `?>` to close Processing Instructions.
htmlparser.piclose = re.compile(r'\?>')
# Monkeypatch `HTMLParser` to only recognize entity references with a closing semicolon.
htmlparser.entityref = re.compile(r'&([a-zA-Z][-.a-zA-Z0-9]*);')
# Monkeypatch `HTMLParser` to no longer support partial entities. We are always feeding a complete block,
# so the 'incomplete' functionality is unnecessary. As the `entityref` regex is run right before incomplete,
# and the two regex are the same, then incomplete will simply never match and we avoid the logic within.
htmlparser.incomplete = htmlparser.entityref
# Monkeypatch `HTMLParser` to not accept a backtick in a tag name, attribute name, or bare value.
htmlparser.locatestarttagend_tolerant = re.compile(r"""
<[a-zA-Z][^`\t\n\r\f />\x00]* # tag name <= added backtick here
(?:[\s/]* # optional whitespace before attribute name
(?:(?<=['"\s/])[^`\s/>][^\s/=>]* # attribute name <= added backtick here
(?:\s*=+\s* # value indicator
(?:'[^']*' # LITA-enclosed value
|"[^"]*" # LIT-enclosed value
|(?!['"])[^`>\s]* # bare value <= added backtick here
)
(?:\s*,)* # possibly followed by a comma
)?(?:\s|/(?!>))*
)*
)?
\s* # trailing whitespace
""", re.VERBOSE)
# Match a blank line at the start of a block of text (two newlines).
# The newlines may be preceded by additional whitespace.
blank_line_re = re.compile(r'^([ ]*\n){2}')
class HTMLExtractor(htmlparser.HTMLParser):
"""
Extract raw HTML from text.
The raw HTML is stored in the [`htmlStash`][markdown.util.HtmlStash] of the
[`Markdown`][markdown.Markdown] instance passed to `md` and the remaining text
is stored in `cleandoc` as a list of strings.
"""
def __init__(self, md, *args, **kwargs):
if 'convert_charrefs' not in kwargs:
kwargs['convert_charrefs'] = False
# Block tags that should contain no content (self closing)
self.empty_tags = set(['hr'])
self.lineno_start_cache = [0]
# This calls self.reset
super().__init__(*args, **kwargs)
self.md = md
def reset(self):
"""Reset this instance. Loses all unprocessed data."""
self.inraw = False
self.intail = False
self.stack = [] # When `inraw==True`, stack contains a list of tags
self._cache = []
self.cleandoc = []
self.lineno_start_cache = [0]
super().reset()
def close(self):
"""Handle any buffered data."""
super().close()
if len(self.rawdata):
# Temp fix for https://bugs.python.org/issue41989
# TODO: remove this when the bug is fixed in all supported Python versions.
if self.convert_charrefs and not self.cdata_elem: # pragma: no cover
self.handle_data(htmlparser.unescape(self.rawdata))
else:
self.handle_data(self.rawdata)
# Handle any unclosed tags.
if len(self._cache):
self.cleandoc.append(self.md.htmlStash.store(''.join(self._cache)))
self._cache = []
@property
def line_offset(self) -> int:
"""Returns char index in `self.rawdata` for the start of the current line. """
for ii in range(len(self.lineno_start_cache)-1, self.lineno-1):
last_line_start_pos = self.lineno_start_cache[ii]
lf_pos = self.rawdata.find('\n', last_line_start_pos)
if lf_pos == -1:
# No more newlines found. Use end of raw data as start of line beyond end.
lf_pos = len(self.rawdata)
self.lineno_start_cache.append(lf_pos+1)
return self.lineno_start_cache[self.lineno-1]
def at_line_start(self) -> bool:
"""
Returns True if current position is at start of line.
Allows for up to three blank spaces at start of line.
"""
if self.offset == 0:
return True
if self.offset > 3:
return False
# Confirm up to first 3 chars are whitespace
return self.rawdata[self.line_offset:self.line_offset + self.offset].strip() == ''
def get_endtag_text(self, tag: str) -> str:
"""
Returns the text of the end tag.
If it fails to extract the actual text from the raw data, it builds a closing tag with `tag`.
"""
# Attempt to extract actual tag from raw source text
start = self.line_offset + self.offset
m = htmlparser.endendtag.search(self.rawdata, start)
if m:
return self.rawdata[start:m.end()]
else: # pragma: no cover
# Failed to extract from raw data. Assume well formed and lowercase.
return '</{}>'.format(tag)
def handle_starttag(self, tag: str, attrs: list[tuple[str, str]]):
# Handle tags that should always be empty and do not specify a closing tag
if tag in self.empty_tags:
self.handle_startendtag(tag, attrs)
return
if self.md.is_block_level(tag) and (self.intail or (self.at_line_start() and not self.inraw)):
# Started a new raw block. Prepare stack.
self.inraw = True
self.cleandoc.append('\n')
text = self.get_starttag_text()
if self.inraw:
self.stack.append(tag)
self._cache.append(text)
else:
self.cleandoc.append(text)
if tag in self.CDATA_CONTENT_ELEMENTS:
# This is presumably a standalone tag in a code span (see #1036).
self.clear_cdata_mode()
def handle_endtag(self, tag: str):
text = self.get_endtag_text(tag)
if self.inraw:
self._cache.append(text)
if tag in self.stack:
# Remove tag from stack
while self.stack:
if self.stack.pop() == tag:
break
if len(self.stack) == 0:
# End of raw block.
if blank_line_re.match(self.rawdata[self.line_offset + self.offset + len(text):]):
# Preserve blank line and end of raw block.
self._cache.append('\n')
else:
# More content exists after `endtag`.
self.intail = True
# Reset stack.
self.inraw = False
self.cleandoc.append(self.md.htmlStash.store(''.join(self._cache)))
# Insert blank line between this and next line.
self.cleandoc.append('\n\n')
self._cache = []
else:
self.cleandoc.append(text)
def handle_data(self, data: str):
if self.intail and '\n' in data:
self.intail = False
if self.inraw:
self._cache.append(data)
else:
self.cleandoc.append(data)
def handle_empty_tag(self, data: str, is_block: bool):
""" Handle empty tags (`<data>`). """
if self.inraw or self.intail:
# Append this to the existing raw block
self._cache.append(data)
elif self.at_line_start() and is_block:
# Handle this as a standalone raw block
if blank_line_re.match(self.rawdata[self.line_offset + self.offset + len(data):]):
# Preserve blank line after tag in raw block.
data += '\n'
else:
# More content exists after tag.
self.intail = True
item = self.cleandoc[-1] if self.cleandoc else ''
# If we only have one newline before block element, add another
if not item.endswith('\n\n') and item.endswith('\n'):
self.cleandoc.append('\n')
self.cleandoc.append(self.md.htmlStash.store(data))
# Insert blank line between this and next line.
self.cleandoc.append('\n\n')
else:
self.cleandoc.append(data)
def handle_startendtag(self, tag: str, attrs: list[tuple[str, str]]):
self.handle_empty_tag(self.get_starttag_text(), is_block=self.md.is_block_level(tag))
def handle_charref(self, name: str):
self.handle_empty_tag('&#{};'.format(name), is_block=False)
def handle_entityref(self, name: str):
self.handle_empty_tag('&{};'.format(name), is_block=False)
def handle_comment(self, data: str):
self.handle_empty_tag('<!--{}-->'.format(data), is_block=True)
def handle_decl(self, data: str):
self.handle_empty_tag('<!{}>'.format(data), is_block=True)
def handle_pi(self, data: str):
self.handle_empty_tag('<?{}?>'.format(data), is_block=True)
def unknown_decl(self, data: str):
end = ']]>' if data.startswith('CDATA[') else ']>'
self.handle_empty_tag('<![{}{}'.format(data, end), is_block=True)
def parse_pi(self, i: int) -> int:
if self.at_line_start() or self.intail:
return super().parse_pi(i)
# This is not the beginning of a raw block so treat as plain data
# and avoid consuming any tags which may follow (see #1066).
self.handle_data('<?')
return i + 2
def parse_html_declaration(self, i: int) -> int:
if self.at_line_start() or self.intail:
return super().parse_html_declaration(i)
# This is not the beginning of a raw block so treat as plain data
# and avoid consuming any tags which may follow (see #1066).
self.handle_data('<!')
return i + 2
# The rest has been copied from base class in standard lib to address #1036.
# As `__startag_text` is private, all references to it must be in this subclass.
# The last few lines of `parse_starttag` are reversed so that `handle_starttag`
# can override `cdata_mode` in certain situations (in a code span).
__starttag_text: str | None = None
def get_starttag_text(self) -> str:
"""Return full source of start tag: `<...>`."""
return self.__starttag_text
def parse_starttag(self, i: int) -> int: # pragma: no cover
self.__starttag_text = None
endpos = self.check_for_whole_start_tag(i)
if endpos < 0:
return endpos
rawdata = self.rawdata
self.__starttag_text = rawdata[i:endpos]
# Now parse the data between `i+1` and `j` into a tag and `attrs`
attrs = []
match = htmlparser.tagfind_tolerant.match(rawdata, i+1)
assert match, 'unexpected call to parse_starttag()'
k = match.end()
self.lasttag = tag = match.group(1).lower()
while k < endpos:
m = htmlparser.attrfind_tolerant.match(rawdata, k)
if not m:
break
attrname, rest, attrvalue = m.group(1, 2, 3)
if not rest:
attrvalue = None
elif attrvalue[:1] == '\'' == attrvalue[-1:] or \
attrvalue[:1] == '"' == attrvalue[-1:]: # noqa: E127
attrvalue = attrvalue[1:-1]
if attrvalue:
attrvalue = htmlparser.unescape(attrvalue)
attrs.append((attrname.lower(), attrvalue))
k = m.end()
end = rawdata[k:endpos].strip()
if end not in (">", "/>"):
lineno, offset = self.getpos()
if "\n" in self.__starttag_text:
lineno = lineno + self.__starttag_text.count("\n")
offset = len(self.__starttag_text) \
- self.__starttag_text.rfind("\n") # noqa: E127
else:
offset = offset + len(self.__starttag_text)
self.handle_data(rawdata[i:endpos])
return endpos
if end.endswith('/>'):
# XHTML-style empty tag: `<span attr="value" />`
self.handle_startendtag(tag, attrs)
else:
# *** set `cdata_mode` first so we can override it in `handle_starttag` (see #1036) ***
if tag in self.CDATA_CONTENT_ELEMENTS:
self.set_cdata_mode(tag)
self.handle_starttag(tag, attrs)
return endpos

View File

@@ -0,0 +1,992 @@
# Python Markdown
# A Python implementation of John Gruber's Markdown.
# Documentation: https://python-markdown.github.io/
# GitHub: https://github.com/Python-Markdown/markdown/
# PyPI: https://pypi.org/project/Markdown/
# Started by Manfred Stienstra (http://www.dwerg.net/).
# Maintained for a few years by Yuri Takhteyev (http://www.freewisdom.org).
# Currently maintained by Waylan Limberg (https://github.com/waylan),
# Dmitry Shachnev (https://github.com/mitya57) and Isaac Muse (https://github.com/facelessuser).
# Copyright 2007-2023 The Python Markdown Project (v. 1.7 and later)
# Copyright 2004, 2005, 2006 Yuri Takhteyev (v. 0.2-1.6b)
# Copyright 2004 Manfred Stienstra (the original version)
# License: BSD (see LICENSE.md for details).
"""
In version 3.0, a new, more flexible inline processor was added, [`markdown.inlinepatterns.InlineProcessor`][]. The
original inline patterns, which inherit from [`markdown.inlinepatterns.Pattern`][] or one of its children are still
supported, though users are encouraged to migrate.
The new `InlineProcessor` provides two major enhancements to `Patterns`:
1. Inline Processors no longer need to match the entire block, so regular expressions no longer need to start with
`r'^(.*?)'` and end with `r'(.*?)%'`. This runs faster. The returned [`Match`][re.Match] object will only contain
what is explicitly matched in the pattern, and extension pattern groups now start with `m.group(1)`.
2. The `handleMatch` method now takes an additional input called `data`, which is the entire block under analysis,
not just what is matched with the specified pattern. The method now returns the element *and* the indexes relative
to `data` that the return element is replacing (usually `m.start(0)` and `m.end(0)`). If the boundaries are
returned as `None`, it is assumed that the match did not take place, and nothing will be altered in `data`.
This allows handling of more complex constructs than regular expressions can handle, e.g., matching nested
brackets, and explicit control of the span "consumed" by the processor.
"""
from __future__ import annotations
from . import util
from typing import TYPE_CHECKING, Any, Collection, NamedTuple
import re
import xml.etree.ElementTree as etree
try: # pragma: no cover
from html import entities
except ImportError: # pragma: no cover
import htmlentitydefs as entities
if TYPE_CHECKING: # pragma: no cover
from markdown import Markdown
def build_inlinepatterns(md: Markdown, **kwargs: Any) -> util.Registry[InlineProcessor]:
"""
Build the default set of inline patterns for Markdown.
The order in which processors and/or patterns are applied is very important - e.g. if we first replace
`http://.../` links with `<a>` tags and _then_ try to replace inline HTML, we would end up with a mess. So, we
apply the expressions in the following order:
* backticks and escaped characters have to be handled before everything else so that we can preempt any markdown
patterns by escaping them;
* then we handle the various types of links (auto-links must be handled before inline HTML);
* then we handle inline HTML. At this point we will simply replace all inline HTML strings with a placeholder
and add the actual HTML to a stash;
* finally we apply strong, emphasis, etc.
"""
inlinePatterns = util.Registry()
inlinePatterns.register(BacktickInlineProcessor(BACKTICK_RE), 'backtick', 190)
inlinePatterns.register(EscapeInlineProcessor(ESCAPE_RE, md), 'escape', 180)
inlinePatterns.register(ReferenceInlineProcessor(REFERENCE_RE, md), 'reference', 170)
inlinePatterns.register(LinkInlineProcessor(LINK_RE, md), 'link', 160)
inlinePatterns.register(ImageInlineProcessor(IMAGE_LINK_RE, md), 'image_link', 150)
inlinePatterns.register(
ImageReferenceInlineProcessor(IMAGE_REFERENCE_RE, md), 'image_reference', 140
)
inlinePatterns.register(
ShortReferenceInlineProcessor(REFERENCE_RE, md), 'short_reference', 130
)
inlinePatterns.register(
ShortImageReferenceInlineProcessor(IMAGE_REFERENCE_RE, md), 'short_image_ref', 125
)
inlinePatterns.register(AutolinkInlineProcessor(AUTOLINK_RE, md), 'autolink', 120)
inlinePatterns.register(AutomailInlineProcessor(AUTOMAIL_RE, md), 'automail', 110)
inlinePatterns.register(SubstituteTagInlineProcessor(LINE_BREAK_RE, 'br'), 'linebreak', 100)
inlinePatterns.register(HtmlInlineProcessor(HTML_RE, md), 'html', 90)
inlinePatterns.register(HtmlInlineProcessor(ENTITY_RE, md), 'entity', 80)
inlinePatterns.register(SimpleTextInlineProcessor(NOT_STRONG_RE), 'not_strong', 70)
inlinePatterns.register(AsteriskProcessor(r'\*'), 'em_strong', 60)
inlinePatterns.register(UnderscoreProcessor(r'_'), 'em_strong2', 50)
return inlinePatterns
# The actual regular expressions for patterns
# -----------------------------------------------------------------------------
NOIMG = r'(?<!\!)'
""" Match not an image. Partial regular expression which matches if not preceded by `!`. """
BACKTICK_RE = r'(?:(?<!\\)((?:\\{2})+)(?=`+)|(?<!\\)(`+)(.+?)(?<!`)\2(?!`))'
""" Match backtick quoted string (`` `e=f()` `` or ``` ``e=f("`")`` ```). """
ESCAPE_RE = r'\\(.)'
""" Match a backslash escaped character (`\\<` or `\\*`). """
EMPHASIS_RE = r'(\*)([^\*]+)\1'
""" Match emphasis with an asterisk (`*emphasis*`). """
STRONG_RE = r'(\*{2})(.+?)\1'
""" Match strong with an asterisk (`**strong**`). """
SMART_STRONG_RE = r'(?<!\w)(_{2})(?!_)(.+?)(?<!_)\1(?!\w)'
""" Match strong with underscore while ignoring middle word underscores (`__smart__strong__`). """
SMART_EMPHASIS_RE = r'(?<!\w)(_)(?!_)(.+?)(?<!_)\1(?!\w)'
""" Match emphasis with underscore while ignoring middle word underscores (`_smart_emphasis_`). """
SMART_STRONG_EM_RE = r'(?<!\w)(\_)\1(?!\1)(.+?)(?<!\w)\1(?!\1)(.+?)\1{3}(?!\w)'
""" Match strong emphasis with underscores (`__strong _em__`). """
EM_STRONG_RE = r'(\*)\1{2}(.+?)\1(.*?)\1{2}'
""" Match emphasis strong with asterisk (`***strongem***` or `***em*strong**`). """
EM_STRONG2_RE = r'(_)\1{2}(.+?)\1(.*?)\1{2}'
""" Match emphasis strong with underscores (`___emstrong___` or `___em_strong__`). """
STRONG_EM_RE = r'(\*)\1{2}(.+?)\1{2}(.*?)\1'
""" Match strong emphasis with asterisk (`***strong**em*`). """
STRONG_EM2_RE = r'(_)\1{2}(.+?)\1{2}(.*?)\1'
""" Match strong emphasis with underscores (`___strong__em_`). """
STRONG_EM3_RE = r'(\*)\1(?!\1)([^*]+?)\1(?!\1)(.+?)\1{3}'
""" Match strong emphasis with asterisk (`**strong*em***`). """
LINK_RE = NOIMG + r'\['
""" Match start of in-line link (`[text](url)` or `[text](<url>)` or `[text](url "title")`). """
IMAGE_LINK_RE = r'\!\['
""" Match start of in-line image link (`![alttxt](url)` or `![alttxt](<url>)`). """
REFERENCE_RE = LINK_RE
""" Match start of reference link (`[Label][3]`). """
IMAGE_REFERENCE_RE = IMAGE_LINK_RE
""" Match start of image reference (`![alt text][2]`). """
NOT_STRONG_RE = r'((^|(?<=\s))(\*{1,3}|_{1,3})(?=\s|$))'
""" Match a stand-alone `*` or `_`. """
AUTOLINK_RE = r'<((?:[Ff]|[Hh][Tt])[Tt][Pp][Ss]?://[^<>]*)>'
""" Match an automatic link (`<http://www.example.com>`). """
AUTOMAIL_RE = r'<([^<> !]+@[^@<> ]+)>'
""" Match an automatic email link (`<me@example.com>`). """
HTML_RE = r'(<(\/?[a-zA-Z][^<>@ ]*( [^<>]*)?|!--(?:(?!<!--|-->).)*--)>)'
""" Match an HTML tag (`<...>`). """
ENTITY_RE = r'(&(?:\#[0-9]+|\#x[0-9a-fA-F]+|[a-zA-Z0-9]+);)'
""" Match an HTML entity (`&#38;` (decimal) or `&#x26;` (hex) or `&amp;` (named)). """
LINE_BREAK_RE = r' \n'
""" Match two spaces at end of line. """
def dequote(string: str) -> str:
"""Remove quotes from around a string."""
if ((string.startswith('"') and string.endswith('"')) or
(string.startswith("'") and string.endswith("'"))):
return string[1:-1]
else:
return string
class EmStrongItem(NamedTuple):
"""Emphasis/strong pattern item."""
pattern: re.Pattern[str]
builder: str
tags: str
# The pattern classes
# -----------------------------------------------------------------------------
class Pattern: # pragma: no cover
"""
Base class that inline patterns subclass.
Inline patterns are handled by means of `Pattern` subclasses, one per regular expression.
Each pattern object uses a single regular expression and must support the following methods:
[`getCompiledRegExp`][markdown.inlinepatterns.Pattern.getCompiledRegExp] and
[`handleMatch`][markdown.inlinepatterns.Pattern.handleMatch].
All the regular expressions used by `Pattern` subclasses must capture the whole block. For this
reason, they all start with `^(.*)` and end with `(.*)!`. When passing a regular expression on
class initialization, the `^(.*)` and `(.*)!` are added automatically and the regular expression
is pre-compiled.
It is strongly suggested that the newer style [`markdown.inlinepatterns.InlineProcessor`][] that
use a more efficient and flexible search approach be used instead. However, the older style
`Pattern` remains for backward compatibility with many existing third-party extensions.
"""
ANCESTOR_EXCLUDES: Collection[str] = tuple()
"""
A collection of elements which are undesirable ancestors. The processor will be skipped if it
would cause the content to be a descendant of one of the listed tag names.
"""
def __init__(self, pattern: str, md: Markdown | None = None):
"""
Create an instant of an inline pattern.
Arguments:
pattern: A regular expression that matches a pattern.
md: An optional pointer to the instance of `markdown.Markdown` and is available as
`self.md` on the class instance.
"""
self.pattern = pattern
self.compiled_re = re.compile(r"^(.*?)%s(.*)$" % pattern,
re.DOTALL | re.UNICODE)
self.md = md
def getCompiledRegExp(self) -> re.Pattern:
""" Return a compiled regular expression. """
return self.compiled_re
def handleMatch(self, m: re.Match[str]) -> etree.Element | str:
"""Return a ElementTree element from the given match.
Subclasses should override this method.
Arguments:
m: A match object containing a match of the pattern.
Returns: An ElementTree Element object.
"""
pass # pragma: no cover
def type(self) -> str:
""" Return class name, to define pattern type """
return self.__class__.__name__
def unescape(self, text: str) -> str:
""" Return unescaped text given text with an inline placeholder. """
try:
stash = self.md.treeprocessors['inline'].stashed_nodes
except KeyError: # pragma: no cover
return text
def get_stash(m):
id = m.group(1)
if id in stash:
value = stash.get(id)
if isinstance(value, str):
return value
else:
# An `etree` Element - return text content only
return ''.join(value.itertext())
return util.INLINE_PLACEHOLDER_RE.sub(get_stash, text)
class InlineProcessor(Pattern):
"""
Base class that inline processors subclass.
This is the newer style inline processor that uses a more
efficient and flexible search approach.
"""
def __init__(self, pattern: str, md: Markdown | None = None):
"""
Create an instant of an inline processor.
Arguments:
pattern: A regular expression that matches a pattern.
md: An optional pointer to the instance of `markdown.Markdown` and is available as
`self.md` on the class instance.
"""
self.pattern = pattern
self.compiled_re = re.compile(pattern, re.DOTALL | re.UNICODE)
# API for Markdown to pass `safe_mode` into instance
self.safe_mode = False
self.md = md
def handleMatch(self, m: re.Match[str], data: str) -> tuple[etree.Element | str | None, int | None, int | None]:
"""Return a ElementTree element from the given match and the
start and end index of the matched text.
If `start` and/or `end` are returned as `None`, it will be
assumed that the processor did not find a valid region of text.
Subclasses should override this method.
Arguments:
m: A re match object containing a match of the pattern.
data: The buffer currently under analysis.
Returns:
el: The ElementTree element, text or None.
start: The start of the region that has been matched or None.
end: The end of the region that has been matched or None.
"""
pass # pragma: no cover
class SimpleTextPattern(Pattern): # pragma: no cover
""" Return a simple text of `group(2)` of a Pattern. """
def handleMatch(self, m: re.Match[str]) -> str:
""" Return string content of `group(2)` of a matching pattern. """
return m.group(2)
class SimpleTextInlineProcessor(InlineProcessor):
""" Return a simple text of `group(1)` of a Pattern. """
def handleMatch(self, m: re.Match[str], data: str) -> tuple[str, int, int]:
""" Return string content of `group(1)` of a matching pattern. """
return m.group(1), m.start(0), m.end(0)
class EscapeInlineProcessor(InlineProcessor):
""" Return an escaped character. """
def handleMatch(self, m: re.Match[str], data: str) -> tuple[str | None, int, int]:
"""
If the character matched by `group(1)` of a pattern is in [`ESCAPED_CHARS`][markdown.Markdown.ESCAPED_CHARS]
then return the integer representing the character's Unicode code point (as returned by [`ord`][]) wrapped
in [`util.STX`][markdown.util.STX] and [`util.ETX`][markdown.util.ETX].
If the matched character is not in [`ESCAPED_CHARS`][markdown.Markdown.ESCAPED_CHARS], then return `None`.
"""
char = m.group(1)
if char in self.md.ESCAPED_CHARS:
return '{}{}{}'.format(util.STX, ord(char), util.ETX), m.start(0), m.end(0)
else:
return None, m.start(0), m.end(0)
class SimpleTagPattern(Pattern): # pragma: no cover
"""
Return element of type `tag` with a text attribute of `group(3)`
of a Pattern.
"""
def __init__(self, pattern: str, tag: str):
"""
Create an instant of an simple tag pattern.
Arguments:
pattern: A regular expression that matches a pattern.
tag: Tag of element.
"""
Pattern.__init__(self, pattern)
self.tag = tag
""" The tag of the rendered element. """
def handleMatch(self, m: re.Match[str]) -> etree.Element:
"""
Return [`Element`][xml.etree.ElementTree.Element] of type `tag` with the string in `group(3)` of a
matching pattern as the Element's text.
"""
el = etree.Element(self.tag)
el.text = m.group(3)
return el
class SimpleTagInlineProcessor(InlineProcessor):
"""
Return element of type `tag` with a text attribute of `group(2)`
of a Pattern.
"""
def __init__(self, pattern: str, tag: str):
"""
Create an instant of an simple tag processor.
Arguments:
pattern: A regular expression that matches a pattern.
tag: Tag of element.
"""
InlineProcessor.__init__(self, pattern)
self.tag = tag
""" The tag of the rendered element. """
def handleMatch(self, m: re.Match[str], data: str) -> tuple[etree.Element, int, int]: # pragma: no cover
"""
Return [`Element`][xml.etree.ElementTree.Element] of type `tag` with the string in `group(2)` of a
matching pattern as the Element's text.
"""
el = etree.Element(self.tag)
el.text = m.group(2)
return el, m.start(0), m.end(0)
class SubstituteTagPattern(SimpleTagPattern): # pragma: no cover
""" Return an element of type `tag` with no children. """
def handleMatch(self, m: re.Match[str]) -> etree.Element:
""" Return empty [`Element`][xml.etree.ElementTree.Element] of type `tag`. """
return etree.Element(self.tag)
class SubstituteTagInlineProcessor(SimpleTagInlineProcessor):
""" Return an element of type `tag` with no children. """
def handleMatch(self, m: re.Match[str], data: str) -> tuple[etree.Element, int, int]:
""" Return empty [`Element`][xml.etree.ElementTree.Element] of type `tag`. """
return etree.Element(self.tag), m.start(0), m.end(0)
class BacktickInlineProcessor(InlineProcessor):
""" Return a `<code>` element containing the escaped matching text. """
def __init__(self, pattern):
InlineProcessor.__init__(self, pattern)
self.ESCAPED_BSLASH = '{}{}{}'.format(util.STX, ord('\\'), util.ETX)
self.tag = 'code'
""" The tag of the rendered element. """
def handleMatch(self, m: re.Match[str], data: str) -> tuple[etree.Element | str, int, int]:
"""
If the match contains `group(3)` of a pattern, then return a `code`
[`Element`][xml.etree.ElementTree.Element] which contains HTML escaped text (with
[`code_escape`][markdown.util.code_escape]) as an [`AtomicString`][markdown.util.AtomicString].
If the match does not contain `group(3)` then return the text of `group(1)` backslash escaped.
"""
if m.group(3):
el = etree.Element(self.tag)
el.text = util.AtomicString(util.code_escape(m.group(3).strip()))
return el, m.start(0), m.end(0)
else:
return m.group(1).replace('\\\\', self.ESCAPED_BSLASH), m.start(0), m.end(0)
class DoubleTagPattern(SimpleTagPattern): # pragma: no cover
"""Return a ElementTree element nested in tag2 nested in tag1.
Useful for strong emphasis etc.
"""
def handleMatch(self, m: re.Match[str]) -> etree.Element:
"""
Return [`Element`][xml.etree.ElementTree.Element] in following format:
`<tag1><tag2>group(3)</tag2>group(4)</tag2>` where `group(4)` is optional.
"""
tag1, tag2 = self.tag.split(",")
el1 = etree.Element(tag1)
el2 = etree.SubElement(el1, tag2)
el2.text = m.group(3)
if len(m.groups()) == 5:
el2.tail = m.group(4)
return el1
class DoubleTagInlineProcessor(SimpleTagInlineProcessor):
"""Return a ElementTree element nested in tag2 nested in tag1.
Useful for strong emphasis etc.
"""
def handleMatch(self, m: re.Match[str], data: str) -> tuple[etree.Element, int, int]: # pragma: no cover
"""
Return [`Element`][xml.etree.ElementTree.Element] in following format:
`<tag1><tag2>group(2)</tag2>group(3)</tag2>` where `group(3)` is optional.
"""
tag1, tag2 = self.tag.split(",")
el1 = etree.Element(tag1)
el2 = etree.SubElement(el1, tag2)
el2.text = m.group(2)
if len(m.groups()) == 3:
el2.tail = m.group(3)
return el1, m.start(0), m.end(0)
class HtmlInlineProcessor(InlineProcessor):
""" Store raw inline html and return a placeholder. """
def handleMatch(self, m: re.Match[str], data: str) -> tuple[str, int, int]:
""" Store the text of `group(1)` of a pattern and return a placeholder string. """
rawhtml = self.backslash_unescape(self.unescape(m.group(1)))
place_holder = self.md.htmlStash.store(rawhtml)
return place_holder, m.start(0), m.end(0)
def unescape(self, text):
""" Return unescaped text given text with an inline placeholder. """
try:
stash = self.md.treeprocessors['inline'].stashed_nodes
except KeyError: # pragma: no cover
return text
def get_stash(m):
id = m.group(1)
value = stash.get(id)
if value is not None:
try:
return self.md.serializer(value)
except Exception:
return r'\%s' % value
return util.INLINE_PLACEHOLDER_RE.sub(get_stash, text)
def backslash_unescape(self, text):
""" Return text with backslash escapes undone (backslashes are restored). """
try:
RE = self.md.treeprocessors['unescape'].RE
except KeyError: # pragma: no cover
return text
def _unescape(m):
return chr(int(m.group(1)))
return RE.sub(_unescape, text)
class AsteriskProcessor(InlineProcessor):
"""Emphasis processor for handling strong and em matches inside asterisks."""
PATTERNS = [
EmStrongItem(re.compile(EM_STRONG_RE, re.DOTALL | re.UNICODE), 'double', 'strong,em'),
EmStrongItem(re.compile(STRONG_EM_RE, re.DOTALL | re.UNICODE), 'double', 'em,strong'),
EmStrongItem(re.compile(STRONG_EM3_RE, re.DOTALL | re.UNICODE), 'double2', 'strong,em'),
EmStrongItem(re.compile(STRONG_RE, re.DOTALL | re.UNICODE), 'single', 'strong'),
EmStrongItem(re.compile(EMPHASIS_RE, re.DOTALL | re.UNICODE), 'single', 'em')
]
""" The various strong and emphasis patterns handled by this processor. """
def build_single(self, m, tag, idx):
"""Return single tag."""
el1 = etree.Element(tag)
text = m.group(2)
self.parse_sub_patterns(text, el1, None, idx)
return el1
def build_double(self, m, tags, idx):
"""Return double tag."""
tag1, tag2 = tags.split(",")
el1 = etree.Element(tag1)
el2 = etree.Element(tag2)
text = m.group(2)
self.parse_sub_patterns(text, el2, None, idx)
el1.append(el2)
if len(m.groups()) == 3:
text = m.group(3)
self.parse_sub_patterns(text, el1, el2, idx)
return el1
def build_double2(self, m, tags, idx):
"""Return double tags (variant 2): `<strong>text <em>text</em></strong>`."""
tag1, tag2 = tags.split(",")
el1 = etree.Element(tag1)
el2 = etree.Element(tag2)
text = m.group(2)
self.parse_sub_patterns(text, el1, None, idx)
text = m.group(3)
el1.append(el2)
self.parse_sub_patterns(text, el2, None, idx)
return el1
def parse_sub_patterns(self, data, parent, last, idx) -> None:
"""
Parses sub patterns.
`data` (`str`):
text to evaluate.
`parent` (`etree.Element`):
Parent to attach text and sub elements to.
`last` (`etree.Element`):
Last appended child to parent. Can also be None if parent has no children.
`idx` (`int`):
Current pattern index that was used to evaluate the parent.
"""
offset = 0
pos = 0
length = len(data)
while pos < length:
# Find the start of potential emphasis or strong tokens
if self.compiled_re.match(data, pos):
matched = False
# See if the we can match an emphasis/strong pattern
for index, item in enumerate(self.PATTERNS):
# Only evaluate patterns that are after what was used on the parent
if index <= idx:
continue
m = item.pattern.match(data, pos)
if m:
# Append child nodes to parent
# Text nodes should be appended to the last
# child if present, and if not, it should
# be added as the parent's text node.
text = data[offset:m.start(0)]
if text:
if last is not None:
last.tail = text
else:
parent.text = text
el = self.build_element(m, item.builder, item.tags, index)
parent.append(el)
last = el
# Move our position past the matched hunk
offset = pos = m.end(0)
matched = True
if not matched:
# We matched nothing, move on to the next character
pos += 1
else:
# Increment position as no potential emphasis start was found.
pos += 1
# Append any leftover text as a text node.
text = data[offset:]
if text:
if last is not None:
last.tail = text
else:
parent.text = text
def build_element(self, m, builder, tags, index):
"""Element builder."""
if builder == 'double2':
return self.build_double2(m, tags, index)
elif builder == 'double':
return self.build_double(m, tags, index)
else:
return self.build_single(m, tags, index)
def handleMatch(self, m: re.Match[str], data: str) -> tuple[etree.Element | None, int | None, int | None]:
"""Parse patterns."""
el = None
start = None
end = None
for index, item in enumerate(self.PATTERNS):
m1 = item.pattern.match(data, m.start(0))
if m1:
start = m1.start(0)
end = m1.end(0)
el = self.build_element(m1, item.builder, item.tags, index)
break
return el, start, end
class UnderscoreProcessor(AsteriskProcessor):
"""Emphasis processor for handling strong and em matches inside underscores."""
PATTERNS = [
EmStrongItem(re.compile(EM_STRONG2_RE, re.DOTALL | re.UNICODE), 'double', 'strong,em'),
EmStrongItem(re.compile(STRONG_EM2_RE, re.DOTALL | re.UNICODE), 'double', 'em,strong'),
EmStrongItem(re.compile(SMART_STRONG_EM_RE, re.DOTALL | re.UNICODE), 'double2', 'strong,em'),
EmStrongItem(re.compile(SMART_STRONG_RE, re.DOTALL | re.UNICODE), 'single', 'strong'),
EmStrongItem(re.compile(SMART_EMPHASIS_RE, re.DOTALL | re.UNICODE), 'single', 'em')
]
""" The various strong and emphasis patterns handled by this processor. """
class LinkInlineProcessor(InlineProcessor):
""" Return a link element from the given match. """
RE_LINK = re.compile(r'''\(\s*(?:(<[^<>]*>)\s*(?:('[^']*'|"[^"]*")\s*)?\))?''', re.DOTALL | re.UNICODE)
RE_TITLE_CLEAN = re.compile(r'\s')
def handleMatch(self, m: re.Match[str], data: str) -> tuple[etree.Element | None, int | None, int | None]:
""" Return an `a` [`Element`][xml.etree.ElementTree.Element] or `(None, None, None)`. """
text, index, handled = self.getText(data, m.end(0))
if not handled:
return None, None, None
href, title, index, handled = self.getLink(data, index)
if not handled:
return None, None, None
el = etree.Element("a")
el.text = text
el.set("href", href)
if title is not None:
el.set("title", title)
return el, m.start(0), index
def getLink(self, data, index):
"""Parse data between `()` of `[Text]()` allowing recursive `()`. """
href = ''
title = None
handled = False
m = self.RE_LINK.match(data, pos=index)
if m and m.group(1):
# Matches [Text](<link> "title")
href = m.group(1)[1:-1].strip()
if m.group(2):
title = m.group(2)[1:-1]
index = m.end(0)
handled = True
elif m:
# Track bracket nesting and index in string
bracket_count = 1
backtrack_count = 1
start_index = m.end()
index = start_index
last_bracket = -1
# Primary (first found) quote tracking.
quote = None
start_quote = -1
exit_quote = -1
ignore_matches = False
# Secondary (second found) quote tracking.
alt_quote = None
start_alt_quote = -1
exit_alt_quote = -1
# Track last character
last = ''
for pos in range(index, len(data)):
c = data[pos]
if c == '(':
# Count nested (
# Don't increment the bracket count if we are sure we're in a title.
if not ignore_matches:
bracket_count += 1
elif backtrack_count > 0:
backtrack_count -= 1
elif c == ')':
# Match nested ) to (
# Don't decrement if we are sure we are in a title that is unclosed.
if ((exit_quote != -1 and quote == last) or (exit_alt_quote != -1 and alt_quote == last)):
bracket_count = 0
elif not ignore_matches:
bracket_count -= 1
elif backtrack_count > 0:
backtrack_count -= 1
# We've found our backup end location if the title doesn't resolve.
if backtrack_count == 0:
last_bracket = index + 1
elif c in ("'", '"'):
# Quote has started
if not quote:
# We'll assume we are now in a title.
# Brackets are quoted, so no need to match them (except for the final one).
ignore_matches = True
backtrack_count = bracket_count
bracket_count = 1
start_quote = index + 1
quote = c
# Secondary quote (in case the first doesn't resolve): [text](link'"title")
elif c != quote and not alt_quote:
start_alt_quote = index + 1
alt_quote = c
# Update primary quote match
elif c == quote:
exit_quote = index + 1
# Update secondary quote match
elif alt_quote and c == alt_quote:
exit_alt_quote = index + 1
index += 1
# Link is closed, so let's break out of the loop
if bracket_count == 0:
# Get the title if we closed a title string right before link closed
if exit_quote >= 0 and quote == last:
href = data[start_index:start_quote - 1]
title = ''.join(data[start_quote:exit_quote - 1])
elif exit_alt_quote >= 0 and alt_quote == last:
href = data[start_index:start_alt_quote - 1]
title = ''.join(data[start_alt_quote:exit_alt_quote - 1])
else:
href = data[start_index:index - 1]
break
if c != ' ':
last = c
# We have a scenario: `[test](link"notitle)`
# When we enter a string, we stop tracking bracket resolution in the main counter,
# but we do keep a backup counter up until we discover where we might resolve all brackets
# if the title string fails to resolve.
if bracket_count != 0 and backtrack_count == 0:
href = data[start_index:last_bracket - 1]
index = last_bracket
bracket_count = 0
handled = bracket_count == 0
if title is not None:
title = self.RE_TITLE_CLEAN.sub(' ', dequote(self.unescape(title.strip())))
href = self.unescape(href).strip()
return href, title, index, handled
def getText(self, data, index):
"""Parse the content between `[]` of the start of an image or link
resolving nested square brackets.
"""
bracket_count = 1
text = []
for pos in range(index, len(data)):
c = data[pos]
if c == ']':
bracket_count -= 1
elif c == '[':
bracket_count += 1
index += 1
if bracket_count == 0:
break
text.append(c)
return ''.join(text), index, bracket_count == 0
class ImageInlineProcessor(LinkInlineProcessor):
""" Return a `img` element from the given match. """
def handleMatch(self, m: re.Match[str], data: str) -> tuple[etree.Element | None, int | None, int | None]:
""" Return an `img` [`Element`][xml.etree.ElementTree.Element] or `(None, None, None)`. """
text, index, handled = self.getText(data, m.end(0))
if not handled:
return None, None, None
src, title, index, handled = self.getLink(data, index)
if not handled:
return None, None, None
el = etree.Element("img")
el.set("src", src)
if title is not None:
el.set("title", title)
el.set('alt', self.unescape(text))
return el, m.start(0), index
class ReferenceInlineProcessor(LinkInlineProcessor):
""" Match to a stored reference and return link element. """
NEWLINE_CLEANUP_RE = re.compile(r'\s+', re.MULTILINE)
RE_LINK = re.compile(r'\s?\[([^\]]*)\]', re.DOTALL | re.UNICODE)
def handleMatch(self, m: re.Match[str], data: str) -> tuple[etree.Element | None, int | None, int | None]:
"""
Return [`Element`][xml.etree.ElementTree.Element] returned by `makeTag` method or `(None, None, None)`.
"""
text, index, handled = self.getText(data, m.end(0))
if not handled:
return None, None, None
id, end, handled = self.evalId(data, index, text)
if not handled:
return None, None, None
# Clean up line breaks in id
id = self.NEWLINE_CLEANUP_RE.sub(' ', id)
if id not in self.md.references: # ignore undefined refs
return None, m.start(0), end
href, title = self.md.references[id]
return self.makeTag(href, title, text), m.start(0), end
def evalId(self, data, index, text):
"""
Evaluate the id portion of `[ref][id]`.
If `[ref][]` use `[ref]`.
"""
m = self.RE_LINK.match(data, pos=index)
if not m:
return None, index, False
else:
id = m.group(1).lower()
end = m.end(0)
if not id:
id = text.lower()
return id, end, True
def makeTag(self, href: str, title: str, text: str) -> etree.Element:
""" Return an `a` [`Element`][xml.etree.ElementTree.Element]. """
el = etree.Element('a')
el.set('href', href)
if title:
el.set('title', title)
el.text = text
return el
class ShortReferenceInlineProcessor(ReferenceInlineProcessor):
"""Short form of reference: `[google]`. """
def evalId(self, data, index, text):
"""Evaluate the id of `[ref]`. """
return text.lower(), index, True
class ImageReferenceInlineProcessor(ReferenceInlineProcessor):
""" Match to a stored reference and return `img` element. """
def makeTag(self, href: str, title: str, text: str) -> etree.Element:
""" Return an `img` [`Element`][xml.etree.ElementTree.Element]. """
el = etree.Element("img")
el.set("src", href)
if title:
el.set("title", title)
el.set("alt", self.unescape(text))
return el
class ShortImageReferenceInlineProcessor(ImageReferenceInlineProcessor):
""" Short form of image reference: `![ref]`. """
def evalId(self, data, index, text):
"""Evaluate the id of `[ref]`. """
return text.lower(), index, True
class AutolinkInlineProcessor(InlineProcessor):
""" Return a link Element given an auto-link (`<http://example/com>`). """
def handleMatch(self, m: re.Match[str], data: str) -> tuple[etree.Element, int, int]:
""" Return an `a` [`Element`][xml.etree.ElementTree.Element] of `group(1)`. """
el = etree.Element("a")
el.set('href', self.unescape(m.group(1)))
el.text = util.AtomicString(m.group(1))
return el, m.start(0), m.end(0)
class AutomailInlineProcessor(InlineProcessor):
"""
Return a `mailto` link Element given an auto-mail link (`<foo@example.com>`).
"""
def handleMatch(self, m: re.Match[str], data: str) -> tuple[etree.Element, int, int]:
""" Return an [`Element`][xml.etree.ElementTree.Element] containing a `mailto` link of `group(1)`. """
el = etree.Element('a')
email = self.unescape(m.group(1))
if email.startswith("mailto:"):
email = email[len("mailto:"):]
def codepoint2name(code):
"""Return entity definition by code, or the code if not defined."""
entity = entities.codepoint2name.get(code)
if entity:
return "{}{};".format(util.AMP_SUBSTITUTE, entity)
else:
return "%s#%d;" % (util.AMP_SUBSTITUTE, code)
letters = [codepoint2name(ord(letter)) for letter in email]
el.text = util.AtomicString(''.join(letters))
mailto = "mailto:" + email
mailto = "".join([util.AMP_SUBSTITUTE + '#%d;' %
ord(letter) for letter in mailto])
el.set('href', mailto)
return el, m.start(0), m.end(0)

View File

@@ -0,0 +1,143 @@
# Python Markdown
# A Python implementation of John Gruber's Markdown.
# Documentation: https://python-markdown.github.io/
# GitHub: https://github.com/Python-Markdown/markdown/
# PyPI: https://pypi.org/project/Markdown/
# Started by Manfred Stienstra (http://www.dwerg.net/).
# Maintained for a few years by Yuri Takhteyev (http://www.freewisdom.org).
# Currently maintained by Waylan Limberg (https://github.com/waylan),
# Dmitry Shachnev (https://github.com/mitya57) and Isaac Muse (https://github.com/facelessuser).
# Copyright 2007-2023 The Python Markdown Project (v. 1.7 and later)
# Copyright 2004, 2005, 2006 Yuri Takhteyev (v. 0.2-1.6b)
# Copyright 2004 Manfred Stienstra (the original version)
# License: BSD (see LICENSE.md for details).
"""
Post-processors run on the text of the entire document after is has been serialized into a string.
Postprocessors should be used to work with the text just before output. Usually, they are used add
back sections that were extracted in a preprocessor, fix up outgoing encodings, or wrap the whole
document.
"""
from __future__ import annotations
from collections import OrderedDict
from typing import TYPE_CHECKING, Any
from . import util
import re
if TYPE_CHECKING: # pragma: no cover
from markdown import Markdown
def build_postprocessors(md: Markdown, **kwargs: Any) -> util.Registry[Postprocessor]:
""" Build the default postprocessors for Markdown. """
postprocessors = util.Registry()
postprocessors.register(RawHtmlPostprocessor(md), 'raw_html', 30)
postprocessors.register(AndSubstitutePostprocessor(), 'amp_substitute', 20)
return postprocessors
class Postprocessor(util.Processor):
"""
Postprocessors are run after the ElementTree it converted back into text.
Each Postprocessor implements a `run` method that takes a pointer to a
text string, modifies it as necessary and returns a text string.
Postprocessors must extend `Postprocessor`.
"""
def run(self, text: str) -> str:
"""
Subclasses of `Postprocessor` should implement a `run` method, which
takes the html document as a single text string and returns a
(possibly modified) string.
"""
pass # pragma: no cover
class RawHtmlPostprocessor(Postprocessor):
""" Restore raw html to the document. """
BLOCK_LEVEL_REGEX = re.compile(r'^\<\/?([^ >]+)')
def run(self, text: str):
""" Iterate over html stash and restore html. """
replacements = OrderedDict()
for i in range(self.md.htmlStash.html_counter):
html = self.stash_to_string(self.md.htmlStash.rawHtmlBlocks[i])
if self.isblocklevel(html):
replacements["<p>{}</p>".format(
self.md.htmlStash.get_placeholder(i))] = html
replacements[self.md.htmlStash.get_placeholder(i)] = html
def substitute_match(m):
key = m.group(0)
if key not in replacements:
if key[3:-4] in replacements:
return f'<p>{ replacements[key[3:-4]] }</p>'
else:
return key
return replacements[key]
if replacements:
base_placeholder = util.HTML_PLACEHOLDER % r'([0-9]+)'
pattern = re.compile(f'<p>{ base_placeholder }</p>|{ base_placeholder }')
processed_text = pattern.sub(substitute_match, text)
else:
return text
if processed_text == text:
return processed_text
else:
return self.run(processed_text)
def isblocklevel(self, html: str) -> bool:
""" Check is block of HTML is block-level. """
m = self.BLOCK_LEVEL_REGEX.match(html)
if m:
if m.group(1)[0] in ('!', '?', '@', '%'):
# Comment, PHP etc...
return True
return self.md.is_block_level(m.group(1))
return False
def stash_to_string(self, text: str) -> str:
""" Convert a stashed object to a string. """
return str(text)
class AndSubstitutePostprocessor(Postprocessor):
""" Restore valid entities """
def run(self, text):
text = text.replace(util.AMP_SUBSTITUTE, "&")
return text
@util.deprecated(
"This class is deprecated and will be removed in the future; "
"use [`UnescapeTreeprocessor`][markdown.treeprocessors.UnescapeTreeprocessor] instead."
)
class UnescapePostprocessor(Postprocessor):
""" Restore escaped chars. """
RE = re.compile(r'{}(\d+){}'.format(util.STX, util.ETX))
def unescape(self, m):
return chr(int(m.group(1)))
def run(self, text):
return self.RE.sub(self.unescape, text)

View File

@@ -0,0 +1,91 @@
# Python Markdown
# A Python implementation of John Gruber's Markdown.
# Documentation: https://python-markdown.github.io/
# GitHub: https://github.com/Python-Markdown/markdown/
# PyPI: https://pypi.org/project/Markdown/
# Started by Manfred Stienstra (http://www.dwerg.net/).
# Maintained for a few years by Yuri Takhteyev (http://www.freewisdom.org).
# Currently maintained by Waylan Limberg (https://github.com/waylan),
# Dmitry Shachnev (https://github.com/mitya57) and Isaac Muse (https://github.com/facelessuser).
# Copyright 2007-2023 The Python Markdown Project (v. 1.7 and later)
# Copyright 2004, 2005, 2006 Yuri Takhteyev (v. 0.2-1.6b)
# Copyright 2004 Manfred Stienstra (the original version)
# License: BSD (see LICENSE.md for details).
"""
Preprocessors work on source text before it is broken down into its individual parts.
This is an excellent place to clean up bad characters or to extract portions for later
processing that the parser may otherwise choke on.
"""
from __future__ import annotations
from typing import TYPE_CHECKING, Any
from . import util
from .htmlparser import HTMLExtractor
import re
if TYPE_CHECKING: # pragma: no cover
from markdown import Markdown
def build_preprocessors(md: Markdown, **kwargs: Any) -> util.Registry[Preprocessor]:
""" Build and return the default set of preprocessors used by Markdown. """
preprocessors = util.Registry()
preprocessors.register(NormalizeWhitespace(md), 'normalize_whitespace', 30)
preprocessors.register(HtmlBlockPreprocessor(md), 'html_block', 20)
return preprocessors
class Preprocessor(util.Processor):
"""
Preprocessors are run after the text is broken into lines.
Each preprocessor implements a `run` method that takes a pointer to a
list of lines of the document, modifies it as necessary and returns
either the same pointer or a pointer to a new list.
Preprocessors must extend `Preprocessor`.
"""
def run(self, lines: list[str]) -> list[str]:
"""
Each subclass of `Preprocessor` should override the `run` method, which
takes the document as a list of strings split by newlines and returns
the (possibly modified) list of lines.
"""
pass # pragma: no cover
class NormalizeWhitespace(Preprocessor):
""" Normalize whitespace for consistent parsing. """
def run(self, lines: list[str]) -> list[str]:
source = '\n'.join(lines)
source = source.replace(util.STX, "").replace(util.ETX, "")
source = source.replace("\r\n", "\n").replace("\r", "\n") + "\n\n"
source = source.expandtabs(self.md.tab_length)
source = re.sub(r'(?<=\n) +\n', '\n', source)
return source.split('\n')
class HtmlBlockPreprocessor(Preprocessor):
"""
Remove html blocks from the text and store them for later retrieval.
The raw HTML is stored in the [`htmlStash`][markdown.util.HtmlStash] of the
[`Markdown`][markdown.Markdown] instance.
"""
def run(self, lines: list[str]) -> list[str]:
source = '\n'.join(lines)
parser = HTMLExtractor(self.md)
parser.feed(source)
parser.close()
return ''.join(parser.cleandoc).split('\n')

View File

@@ -0,0 +1,193 @@
# Add x/html serialization to `Elementree`
# Taken from ElementTree 1.3 preview with slight modifications
#
# Copyright (c) 1999-2007 by Fredrik Lundh. All rights reserved.
#
# fredrik@pythonware.com
# https://www.pythonware.com/
#
# --------------------------------------------------------------------
# The ElementTree toolkit is
#
# Copyright (c) 1999-2007 by Fredrik Lundh
#
# By obtaining, using, and/or copying this software and/or its
# associated documentation, you agree that you have read, understood,
# and will comply with the following terms and conditions:
#
# Permission to use, copy, modify, and distribute this software and
# its associated documentation for any purpose and without fee is
# hereby granted, provided that the above copyright notice appears in
# all copies, and that both that copyright notice and this permission
# notice appear in supporting documentation, and that the name of
# Secret Labs AB or the author not be used in advertising or publicity
# pertaining to distribution of the software without specific, written
# prior permission.
#
# SECRET LABS AB AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH REGARD
# TO THIS SOFTWARE, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANT-
# ABILITY AND FITNESS. IN NO EVENT SHALL SECRET LABS AB OR THE AUTHOR
# BE LIABLE FOR ANY SPECIAL, INDIRECT OR CONSEQUENTIAL DAMAGES OR ANY
# DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
# WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
# ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
# OF THIS SOFTWARE.
# --------------------------------------------------------------------
"""
Python-Markdown provides two serializers which render [`ElementTree.Element`][xml.etree.ElementTree.Element]
objects to a string of HTML. Both functions wrap the same underlying code with only a few minor
differences as outlined below:
1. Empty (self-closing) tags are rendered as `<tag>` for HTML and as `<tag />` for XHTML.
2. Boolean attributes are rendered as `attrname` for HTML and as `attrname="attrname"` for XHTML.
"""
from __future__ import annotations
from xml.etree.ElementTree import ProcessingInstruction
from xml.etree.ElementTree import Comment, ElementTree, Element, QName, HTML_EMPTY
import re
__all__ = ['to_html_string', 'to_xhtml_string']
RE_AMP = re.compile(r'&(?!(?:\#[0-9]+|\#x[0-9a-f]+|[0-9a-z]+);)', re.I)
def _raise_serialization_error(text): # pragma: no cover
raise TypeError(
"cannot serialize {!r} (type {})".format(text, type(text).__name__)
)
def _escape_cdata(text):
# escape character data
try:
# it's worth avoiding do-nothing calls for strings that are
# shorter than 500 character, or so. assume that's, by far,
# the most common case in most applications.
if "&" in text:
# Only replace & when not part of an entity
text = RE_AMP.sub('&amp;', text)
if "<" in text:
text = text.replace("<", "&lt;")
if ">" in text:
text = text.replace(">", "&gt;")
return text
except (TypeError, AttributeError): # pragma: no cover
_raise_serialization_error(text)
def _escape_attrib(text):
# escape attribute value
try:
if "&" in text:
# Only replace & when not part of an entity
text = RE_AMP.sub('&amp;', text)
if "<" in text:
text = text.replace("<", "&lt;")
if ">" in text:
text = text.replace(">", "&gt;")
if "\"" in text:
text = text.replace("\"", "&quot;")
if "\n" in text:
text = text.replace("\n", "&#10;")
return text
except (TypeError, AttributeError): # pragma: no cover
_raise_serialization_error(text)
def _escape_attrib_html(text):
# escape attribute value
try:
if "&" in text:
# Only replace & when not part of an entity
text = RE_AMP.sub('&amp;', text)
if "<" in text:
text = text.replace("<", "&lt;")
if ">" in text:
text = text.replace(">", "&gt;")
if "\"" in text:
text = text.replace("\"", "&quot;")
return text
except (TypeError, AttributeError): # pragma: no cover
_raise_serialization_error(text)
def _serialize_html(write, elem, format):
tag = elem.tag
text = elem.text
if tag is Comment:
write("<!--%s-->" % _escape_cdata(text))
elif tag is ProcessingInstruction:
write("<?%s?>" % _escape_cdata(text))
elif tag is None:
if text:
write(_escape_cdata(text))
for e in elem:
_serialize_html(write, e, format)
else:
namespace_uri = None
if isinstance(tag, QName):
# `QNAME` objects store their data as a string: `{uri}tag`
if tag.text[:1] == "{":
namespace_uri, tag = tag.text[1:].split("}", 1)
else:
raise ValueError('QName objects must define a tag.')
write("<" + tag)
items = elem.items()
if items:
items = sorted(items) # lexical order
for k, v in items:
if isinstance(k, QName):
# Assume a text only `QName`
k = k.text
if isinstance(v, QName):
# Assume a text only `QName`
v = v.text
else:
v = _escape_attrib_html(v)
if k == v and format == 'html':
# handle boolean attributes
write(" %s" % v)
else:
write(' {}="{}"'.format(k, v))
if namespace_uri:
write(' xmlns="%s"' % (_escape_attrib(namespace_uri)))
if format == "xhtml" and tag.lower() in HTML_EMPTY:
write(" />")
else:
write(">")
if text:
if tag.lower() in ["script", "style"]:
write(text)
else:
write(_escape_cdata(text))
for e in elem:
_serialize_html(write, e, format)
if tag.lower() not in HTML_EMPTY:
write("</" + tag + ">")
if elem.tail:
write(_escape_cdata(elem.tail))
def _write_html(root, format="html"):
assert root is not None
data = []
write = data.append
_serialize_html(write, root, format)
return "".join(data)
# --------------------------------------------------------------------
# public functions
def to_html_string(element: Element) -> str:
""" Serialize element and its children to a string of HTML5. """
return _write_html(ElementTree(element).getroot(), format="html")
def to_xhtml_string(element: Element) -> str:
""" Serialize element and its children to a string of XHTML. """
return _write_html(ElementTree(element).getroot(), format="xhtml")

View File

@@ -0,0 +1,224 @@
# Python Markdown
# A Python implementation of John Gruber's Markdown.
# Documentation: https://python-markdown.github.io/
# GitHub: https://github.com/Python-Markdown/markdown/
# PyPI: https://pypi.org/project/Markdown/
# Started by Manfred Stienstra (http://www.dwerg.net/).
# Maintained for a few years by Yuri Takhteyev (http://www.freewisdom.org).
# Currently maintained by Waylan Limberg (https://github.com/waylan),
# Dmitry Shachnev (https://github.com/mitya57) and Isaac Muse (https://github.com/facelessuser).
# Copyright 2007-2023 The Python Markdown Project (v. 1.7 and later)
# Copyright 2004, 2005, 2006 Yuri Takhteyev (v. 0.2-1.6b)
# Copyright 2004 Manfred Stienstra (the original version)
# License: BSD (see LICENSE.md for details).
""" A collection of tools for testing the Markdown code base and extensions. """
from __future__ import annotations
import os
import sys
import unittest
import textwrap
from typing import Any
from . import markdown, Markdown, util
try:
import tidylib
except ImportError:
tidylib = None
__all__ = ['TestCase', 'LegacyTestCase', 'Kwargs']
class TestCase(unittest.TestCase):
"""
A [`unittest.TestCase`][] subclass with helpers for testing Markdown output.
Define `default_kwargs` as a `dict` of keywords to pass to Markdown for each
test. The defaults can be overridden on individual tests.
The `assertMarkdownRenders` method accepts the source text, the expected
output, and any keywords to pass to Markdown. The `default_kwargs` are used
except where overridden by `kwargs`. The output and expected output are passed
to `TestCase.assertMultiLineEqual`. An `AssertionError` is raised with a diff
if the actual output does not equal the expected output.
The `dedent` method is available to dedent triple-quoted strings if
necessary.
In all other respects, behaves as `unittest.TestCase`.
"""
default_kwargs: dict[str, Any] = {}
""" Default options to pass to Markdown for each test. """
def assertMarkdownRenders(self, source, expected, expected_attrs=None, **kwargs):
"""
Test that source Markdown text renders to expected output with given keywords.
`expected_attrs` accepts a `dict`. Each key should be the name of an attribute
on the `Markdown` instance and the value should be the expected value after
the source text is parsed by Markdown. After the expected output is tested,
the expected value for each attribute is compared against the actual
attribute of the `Markdown` instance using `TestCase.assertEqual`.
"""
expected_attrs = expected_attrs or {}
kws = self.default_kwargs.copy()
kws.update(kwargs)
md = Markdown(**kws)
output = md.convert(source)
self.assertMultiLineEqual(output, expected)
for key, value in expected_attrs.items():
self.assertEqual(getattr(md, key), value)
def dedent(self, text):
"""
Dedent text.
"""
# TODO: If/when actual output ends with a newline, then use:
# return textwrap.dedent(text.strip('/n'))
return textwrap.dedent(text).strip()
class recursionlimit:
"""
A context manager which temporarily modifies the Python recursion limit.
The testing framework, coverage, etc. may add an arbitrary number of levels to the depth. To maintain consistency
in the tests, the current stack depth is determined when called, then added to the provided limit.
Example usage:
``` python
with recursionlimit(20):
# test code here
```
See <https://stackoverflow.com/a/50120316/866026>.
"""
def __init__(self, limit):
self.limit = util._get_stack_depth() + limit
self.old_limit = sys.getrecursionlimit()
def __enter__(self):
sys.setrecursionlimit(self.limit)
def __exit__(self, type, value, tb):
sys.setrecursionlimit(self.old_limit)
#########################
# Legacy Test Framework #
#########################
class Kwargs(dict):
""" A `dict` like class for holding keyword arguments. """
pass
def _normalize_whitespace(text):
""" Normalize whitespace for a string of HTML using `tidylib`. """
output, errors = tidylib.tidy_fragment(text, options={
'drop_empty_paras': 0,
'fix_backslash': 0,
'fix_bad_comments': 0,
'fix_uri': 0,
'join_styles': 0,
'lower_literals': 0,
'merge_divs': 0,
'output_xhtml': 1,
'quote_ampersand': 0,
'newline': 'LF'
})
return output
class LegacyTestMeta(type):
def __new__(cls, name, bases, dct):
def generate_test(infile, outfile, normalize, kwargs):
def test(self):
with open(infile, encoding="utf-8") as f:
input = f.read()
with open(outfile, encoding="utf-8") as f:
# Normalize line endings
# (on Windows, git may have altered line endings).
expected = f.read().replace("\r\n", "\n")
output = markdown(input, **kwargs)
if tidylib and normalize:
try:
expected = _normalize_whitespace(expected)
output = _normalize_whitespace(output)
except OSError:
self.skipTest("Tidylib's c library not available.")
elif normalize:
self.skipTest('Tidylib not available.')
self.assertMultiLineEqual(output, expected)
return test
location = dct.get('location', '')
exclude = dct.get('exclude', [])
normalize = dct.get('normalize', False)
input_ext = dct.get('input_ext', '.txt')
output_ext = dct.get('output_ext', '.html')
kwargs = dct.get('default_kwargs', Kwargs())
if os.path.isdir(location):
for file in os.listdir(location):
infile = os.path.join(location, file)
if os.path.isfile(infile):
tname, ext = os.path.splitext(file)
if ext == input_ext:
outfile = os.path.join(location, tname + output_ext)
tname = tname.replace(' ', '_').replace('-', '_')
kws = kwargs.copy()
if tname in dct:
kws.update(dct[tname])
test_name = 'test_%s' % tname
if tname not in exclude:
dct[test_name] = generate_test(infile, outfile, normalize, kws)
else:
dct[test_name] = unittest.skip('Excluded')(lambda: None)
return type.__new__(cls, name, bases, dct)
class LegacyTestCase(unittest.TestCase, metaclass=LegacyTestMeta):
"""
A [`unittest.TestCase`][] subclass for running Markdown's legacy file-based tests.
A subclass should define various properties which point to a directory of
text-based test files and define various behaviors/defaults for those tests.
The following properties are supported:
Attributes:
location (str): A path to the directory of test files. An absolute path is preferred.
exclude (list[str]): A list of tests to exclude. Each test name should comprise the filename
without an extension.
normalize (bool): A boolean value indicating if the HTML should be normalized. Default: `False`.
input_ext (str): A string containing the file extension of input files. Default: `.txt`.
output_ext (str): A string containing the file extension of expected output files. Default: `html`.
default_kwargs (Kwargs[str, Any]): The default set of keyword arguments for all test files in the directory.
In addition, properties can be defined for each individual set of test files within
the directory. The property should be given the name of the file without the file
extension. Any spaces and dashes in the filename should be replaced with
underscores. The value of the property should be a `Kwargs` instance which
contains the keyword arguments that should be passed to `Markdown` for that
test file. The keyword arguments will "update" the `default_kwargs`.
When the class instance is created, it will walk the given directory and create
a separate `Unitttest` for each set of test files using the naming scheme:
`test_filename`. One `Unittest` will be run for each set of input and output files.
"""
pass

View File

@@ -0,0 +1,476 @@
# Python Markdown
# A Python implementation of John Gruber's Markdown.
# Documentation: https://python-markdown.github.io/
# GitHub: https://github.com/Python-Markdown/markdown/
# PyPI: https://pypi.org/project/Markdown/
# Started by Manfred Stienstra (http://www.dwerg.net/).
# Maintained for a few years by Yuri Takhteyev (http://www.freewisdom.org).
# Currently maintained by Waylan Limberg (https://github.com/waylan),
# Dmitry Shachnev (https://github.com/mitya57) and Isaac Muse (https://github.com/facelessuser).
# Copyright 2007-2023 The Python Markdown Project (v. 1.7 and later)
# Copyright 2004, 2005, 2006 Yuri Takhteyev (v. 0.2-1.6b)
# Copyright 2004 Manfred Stienstra (the original version)
# License: BSD (see LICENSE.md for details).
"""
Tree processors manipulate the tree created by block processors. They can even create an entirely
new `ElementTree` object. This is an excellent place for creating summaries, adding collected
references, or last minute adjustments.
"""
from __future__ import annotations
import re
import xml.etree.ElementTree as etree
from typing import TYPE_CHECKING, Any
from . import util
from . import inlinepatterns
if TYPE_CHECKING: # pragma: no cover
from markdown import Markdown
def build_treeprocessors(md: Markdown, **kwargs: Any) -> util.Registry[Treeprocessor]:
""" Build the default `treeprocessors` for Markdown. """
treeprocessors = util.Registry()
treeprocessors.register(InlineProcessor(md), 'inline', 20)
treeprocessors.register(PrettifyTreeprocessor(md), 'prettify', 10)
treeprocessors.register(UnescapeTreeprocessor(md), 'unescape', 0)
return treeprocessors
def isString(s: Any) -> bool:
""" Return `True` if object is a string but not an [`AtomicString`][markdown.util.AtomicString]. """
if not isinstance(s, util.AtomicString):
return isinstance(s, str)
return False
class Treeprocessor(util.Processor):
"""
`Treeprocessor`s are run on the `ElementTree` object before serialization.
Each `Treeprocessor` implements a `run` method that takes a pointer to an
`Element` and modifies it as necessary.
`Treeprocessors` must extend `markdown.Treeprocessor`.
"""
def run(self, root: etree.Element) -> etree.Element | None:
"""
Subclasses of `Treeprocessor` should implement a `run` method, which
takes a root `Element`. This method can return another `Element`
object, and the existing root `Element` will be replaced, or it can
modify the current tree and return `None`.
"""
pass # pragma: no cover
class InlineProcessor(Treeprocessor):
"""
A `Treeprocessor` that traverses a tree, applying inline patterns.
"""
def __init__(self, md):
self.__placeholder_prefix = util.INLINE_PLACEHOLDER_PREFIX
self.__placeholder_suffix = util.ETX
self.__placeholder_length = 4 + len(self.__placeholder_prefix) \
+ len(self.__placeholder_suffix)
self.__placeholder_re = util.INLINE_PLACEHOLDER_RE
self.md = md
self.inlinePatterns = md.inlinePatterns
self.ancestors = []
def __makePlaceholder(self, type) -> tuple[str, str]:
""" Generate a placeholder """
id = "%04d" % len(self.stashed_nodes)
hash = util.INLINE_PLACEHOLDER % id
return hash, id
def __findPlaceholder(self, data: str, index: int) -> tuple[str | None, int]:
"""
Extract id from data string, start from index.
Arguments:
data: String.
index: Index, from which we start search.
Returns:
Placeholder id and string index, after the found placeholder.
"""
m = self.__placeholder_re.search(data, index)
if m:
return m.group(1), m.end()
else:
return None, index + 1
def __stashNode(self, node, type) -> str:
""" Add node to stash. """
placeholder, id = self.__makePlaceholder(type)
self.stashed_nodes[id] = node
return placeholder
def __handleInline(self, data: str, patternIndex: int = 0) -> str:
"""
Process string with inline patterns and replace it with placeholders.
Arguments:
data: A line of Markdown text.
patternIndex: The index of the `inlinePattern` to start with.
Returns:
String with placeholders.
"""
if not isinstance(data, util.AtomicString):
startIndex = 0
count = len(self.inlinePatterns)
while patternIndex < count:
data, matched, startIndex = self.__applyPattern(
self.inlinePatterns[patternIndex], data, patternIndex, startIndex
)
if not matched:
patternIndex += 1
return data
def __processElementText(self, node: etree.Element, subnode: etree.Element, isText: bool = True):
"""
Process placeholders in `Element.text` or `Element.tail`
of Elements popped from `self.stashed_nodes`.
Arguments:
node: Parent node.
subnode: Processing node.
isText: Boolean variable, True - it's text, False - it's a tail.
"""
if isText:
text = subnode.text
subnode.text = None
else:
text = subnode.tail
subnode.tail = None
childResult = self.__processPlaceholders(text, subnode, isText)
if not isText and node is not subnode:
pos = list(node).index(subnode) + 1
else:
pos = 0
childResult.reverse()
for newChild in childResult:
node.insert(pos, newChild[0])
def __processPlaceholders(
self,
data: str,
parent: etree.Element,
isText: bool = True
) -> list[tuple[etree.Element, Any]]:
"""
Process string with placeholders and generate `ElementTree` tree.
Arguments:
data: String with placeholders instead of `ElementTree` elements.
parent: Element, which contains processing inline data.
isText: Boolean variable, True - it's text, False - it's a tail.
Returns:
List with `ElementTree` elements with applied inline patterns.
"""
def linkText(text):
if text:
if result:
if result[-1][0].tail:
result[-1][0].tail += text
else:
result[-1][0].tail = text
elif not isText:
if parent.tail:
parent.tail += text
else:
parent.tail = text
else:
if parent.text:
parent.text += text
else:
parent.text = text
result = []
strartIndex = 0
while data:
index = data.find(self.__placeholder_prefix, strartIndex)
if index != -1:
id, phEndIndex = self.__findPlaceholder(data, index)
if id in self.stashed_nodes:
node = self.stashed_nodes.get(id)
if index > 0:
text = data[strartIndex:index]
linkText(text)
if not isString(node): # it's Element
for child in [node] + list(node):
if child.tail:
if child.tail.strip():
self.__processElementText(
node, child, False
)
if child.text:
if child.text.strip():
self.__processElementText(child, child)
else: # it's just a string
linkText(node)
strartIndex = phEndIndex
continue
strartIndex = phEndIndex
result.append((node, self.ancestors[:]))
else: # wrong placeholder
end = index + len(self.__placeholder_prefix)
linkText(data[strartIndex:end])
strartIndex = end
else:
text = data[strartIndex:]
if isinstance(data, util.AtomicString):
# We don't want to loose the `AtomicString`
text = util.AtomicString(text)
linkText(text)
data = ""
return result
def __applyPattern(
self,
pattern: inlinepatterns.Pattern,
data: str,
patternIndex: int,
startIndex: int = 0
) -> tuple[str, bool, int]:
"""
Check if the line fits the pattern, create the necessary
elements, add it to `stashed_nodes`.
Arguments:
data: The text to be processed.
pattern: The pattern to be checked.
patternIndex: Index of current pattern.
startIndex: String index, from which we start searching.
Returns:
String with placeholders instead of `ElementTree` elements.
"""
new_style = isinstance(pattern, inlinepatterns.InlineProcessor)
for exclude in pattern.ANCESTOR_EXCLUDES:
if exclude.lower() in self.ancestors:
return data, False, 0
if new_style:
match = None
# Since `handleMatch` may reject our first match,
# we iterate over the buffer looking for matches
# until we can't find any more.
for match in pattern.getCompiledRegExp().finditer(data, startIndex):
node, start, end = pattern.handleMatch(match, data)
if start is None or end is None:
startIndex += match.end(0)
match = None
continue
break
else: # pragma: no cover
match = pattern.getCompiledRegExp().match(data[startIndex:])
leftData = data[:startIndex]
if not match:
return data, False, 0
if not new_style: # pragma: no cover
node = pattern.handleMatch(match)
start = match.start(0)
end = match.end(0)
if node is None:
return data, True, end
if not isString(node):
if not isinstance(node.text, util.AtomicString):
# We need to process current node too
for child in [node] + list(node):
if not isString(node):
if child.text:
self.ancestors.append(child.tag.lower())
child.text = self.__handleInline(
child.text, patternIndex + 1
)
self.ancestors.pop()
if child.tail:
child.tail = self.__handleInline(
child.tail, patternIndex
)
placeholder = self.__stashNode(node, pattern.type())
if new_style:
return "{}{}{}".format(data[:start],
placeholder, data[end:]), True, 0
else: # pragma: no cover
return "{}{}{}{}".format(leftData,
match.group(1),
placeholder, match.groups()[-1]), True, 0
def __build_ancestors(self, parent, parents):
"""Build the ancestor list."""
ancestors = []
while parent is not None:
if parent is not None:
ancestors.append(parent.tag.lower())
parent = self.parent_map.get(parent)
ancestors.reverse()
parents.extend(ancestors)
def run(self, tree: etree.Element, ancestors: list[str] | None = None) -> etree.Element:
"""Apply inline patterns to a parsed Markdown tree.
Iterate over `Element`, find elements with inline tag, apply inline
patterns and append newly created Elements to tree. To avoid further
processing of string with inline patterns, instead of normal string,
use subclass [`AtomicString`][markdown.util.AtomicString]:
node.text = markdown.util.AtomicString("This will not be processed.")
Arguments:
tree: `Element` object, representing Markdown tree.
ancestors: List of parent tag names that precede the tree node (if needed).
Returns:
An element tree object with applied inline patterns.
"""
self.stashed_nodes: dict[str, etree.Element] = {}
# Ensure a valid parent list, but copy passed in lists
# to ensure we don't have the user accidentally change it on us.
tree_parents = [] if ancestors is None else ancestors[:]
self.parent_map = {c: p for p in tree.iter() for c in p}
stack = [(tree, tree_parents)]
while stack:
currElement, parents = stack.pop()
self.ancestors = parents
self.__build_ancestors(currElement, self.ancestors)
insertQueue = []
for child in currElement:
if child.text and not isinstance(
child.text, util.AtomicString
):
self.ancestors.append(child.tag.lower())
text = child.text
child.text = None
lst = self.__processPlaceholders(
self.__handleInline(text), child
)
for item in lst:
self.parent_map[item[0]] = child
stack += lst
insertQueue.append((child, lst))
self.ancestors.pop()
if child.tail:
tail = self.__handleInline(child.tail)
dumby = etree.Element('d')
child.tail = None
tailResult = self.__processPlaceholders(tail, dumby, False)
if dumby.tail:
child.tail = dumby.tail
pos = list(currElement).index(child) + 1
tailResult.reverse()
for newChild in tailResult:
self.parent_map[newChild[0]] = currElement
currElement.insert(pos, newChild[0])
if len(child):
self.parent_map[child] = currElement
stack.append((child, self.ancestors[:]))
for element, lst in insertQueue:
for i, obj in enumerate(lst):
newChild = obj[0]
element.insert(i, newChild)
return tree
class PrettifyTreeprocessor(Treeprocessor):
""" Add line breaks to the html document. """
def _prettifyETree(self, elem):
""" Recursively add line breaks to `ElementTree` children. """
i = "\n"
if self.md.is_block_level(elem.tag) and elem.tag not in ['code', 'pre']:
if (not elem.text or not elem.text.strip()) \
and len(elem) and self.md.is_block_level(elem[0].tag):
elem.text = i
for e in elem:
if self.md.is_block_level(e.tag):
self._prettifyETree(e)
if not elem.tail or not elem.tail.strip():
elem.tail = i
def run(self, root: etree.Element) -> None:
""" Add line breaks to `Element` object and its children. """
self._prettifyETree(root)
# Do `<br />`'s separately as they are often in the middle of
# inline content and missed by `_prettifyETree`.
brs = root.iter('br')
for br in brs:
if not br.tail or not br.tail.strip():
br.tail = '\n'
else:
br.tail = '\n%s' % br.tail
# Clean up extra empty lines at end of code blocks.
pres = root.iter('pre')
for pre in pres:
if len(pre) and pre[0].tag == 'code':
code = pre[0]
# Only prettify code containing text only
if not len(code) and code.text is not None:
code.text = util.AtomicString(code.text.rstrip() + '\n')
class UnescapeTreeprocessor(Treeprocessor):
""" Restore escaped chars """
RE = re.compile(r'{}(\d+){}'.format(util.STX, util.ETX))
def _unescape(self, m):
return chr(int(m.group(1)))
def unescape(self, text: str) -> str:
return self.RE.sub(self._unescape, text)
def run(self, root):
""" Loop over all elements and unescape all text. """
for elem in root.iter():
# Unescape text content
if elem.text and not elem.tag == 'code':
elem.text = self.unescape(elem.text)
# Unescape tail content
if elem.tail:
elem.tail = self.unescape(elem.tail)
# Unescape attribute values
for key, value in elem.items():
elem.set(key, self.unescape(value))

View File

@@ -0,0 +1,399 @@
# Python Markdown
# A Python implementation of John Gruber's Markdown.
# Documentation: https://python-markdown.github.io/
# GitHub: https://github.com/Python-Markdown/markdown/
# PyPI: https://pypi.org/project/Markdown/
# Started by Manfred Stienstra (http://www.dwerg.net/).
# Maintained for a few years by Yuri Takhteyev (http://www.freewisdom.org).
# Currently maintained by Waylan Limberg (https://github.com/waylan),
# Dmitry Shachnev (https://github.com/mitya57) and Isaac Muse (https://github.com/facelessuser).
# Copyright 2007-2023 The Python Markdown Project (v. 1.7 and later)
# Copyright 2004, 2005, 2006 Yuri Takhteyev (v. 0.2-1.6b)
# Copyright 2004 Manfred Stienstra (the original version)
# License: BSD (see LICENSE.md for details).
"""
This module contains various contacts, classes and functions which get referenced and used
throughout the code base.
"""
from __future__ import annotations
import re
import sys
import warnings
from functools import wraps, lru_cache
from itertools import count
from typing import TYPE_CHECKING, Generic, Iterator, NamedTuple, TypeVar, overload
if TYPE_CHECKING: # pragma: no cover
from markdown import Markdown
_T = TypeVar('_T')
"""
Constants you might want to modify
-----------------------------------------------------------------------------
"""
BLOCK_LEVEL_ELEMENTS: list[str] = [
# Elements which are invalid to wrap in a `<p>` tag.
# See https://w3c.github.io/html/grouping-content.html#the-p-element
'address', 'article', 'aside', 'blockquote', 'details', 'div', 'dl',
'fieldset', 'figcaption', 'figure', 'footer', 'form', 'h1', 'h2', 'h3',
'h4', 'h5', 'h6', 'header', 'hgroup', 'hr', 'main', 'menu', 'nav', 'ol',
'p', 'pre', 'section', 'table', 'ul',
# Other elements which Markdown should not be mucking up the contents of.
'canvas', 'colgroup', 'dd', 'body', 'dt', 'group', 'html', 'iframe', 'li', 'legend',
'math', 'map', 'noscript', 'output', 'object', 'option', 'progress', 'script',
'style', 'summary', 'tbody', 'td', 'textarea', 'tfoot', 'th', 'thead', 'tr', 'video'
]
"""
List of HTML tags which get treated as block-level elements. Same as the `block_level_elements`
attribute of the [`Markdown`][markdown.Markdown] class. Generally one should use the
attribute on the class. This remains for compatibility with older extensions.
"""
# Placeholders
STX = '\u0002'
""" "Start of Text" marker for placeholder templates. """
ETX = '\u0003'
""" "End of Text" marker for placeholder templates. """
INLINE_PLACEHOLDER_PREFIX = STX+"klzzwxh:"
""" Prefix for inline placeholder template. """
INLINE_PLACEHOLDER = INLINE_PLACEHOLDER_PREFIX + "%s" + ETX
""" Placeholder template for stashed inline text. """
INLINE_PLACEHOLDER_RE = re.compile(INLINE_PLACEHOLDER % r'([0-9]+)')
""" Regular Expression which matches inline placeholders. """
AMP_SUBSTITUTE = STX+"amp"+ETX
""" Placeholder template for HTML entities. """
HTML_PLACEHOLDER = STX + "wzxhzdk:%s" + ETX
""" Placeholder template for raw HTML. """
HTML_PLACEHOLDER_RE = re.compile(HTML_PLACEHOLDER % r'([0-9]+)')
""" Regular expression which matches HTML placeholders. """
TAG_PLACEHOLDER = STX + "hzzhzkh:%s" + ETX
""" Placeholder template for tags. """
# Constants you probably do not need to change
# -----------------------------------------------------------------------------
RTL_BIDI_RANGES = (
('\u0590', '\u07FF'),
# Hebrew (0590-05FF), Arabic (0600-06FF),
# Syriac (0700-074F), Arabic supplement (0750-077F),
# Thaana (0780-07BF), Nko (07C0-07FF).
('\u2D30', '\u2D7F') # Tifinagh
)
# AUXILIARY GLOBAL FUNCTIONS
# =============================================================================
@lru_cache(maxsize=None)
def get_installed_extensions():
""" Return all entry_points in the `markdown.extensions` group. """
if sys.version_info >= (3, 10):
from importlib import metadata
else: # `<PY310` use backport
import importlib_metadata as metadata
# Only load extension entry_points once.
return metadata.entry_points(group='markdown.extensions')
def deprecated(message: str, stacklevel: int = 2):
"""
Raise a [`DeprecationWarning`][] when wrapped function/method is called.
Usage:
```python
@deprecated("This method will be removed in version X; use Y instead.")
def some_method():
pass
```
"""
def wrapper(func):
@wraps(func)
def deprecated_func(*args, **kwargs):
warnings.warn(
f"'{func.__name__}' is deprecated. {message}",
category=DeprecationWarning,
stacklevel=stacklevel
)
return func(*args, **kwargs)
return deprecated_func
return wrapper
def parseBoolValue(value: str | None, fail_on_errors: bool = True, preserve_none: bool = False) -> bool | None:
"""Parses a string representing a boolean value. If parsing was successful,
returns `True` or `False`. If `preserve_none=True`, returns `True`, `False`,
or `None`. If parsing was not successful, raises `ValueError`, or, if
`fail_on_errors=False`, returns `None`."""
if not isinstance(value, str):
if preserve_none and value is None:
return value
return bool(value)
elif preserve_none and value.lower() == 'none':
return None
elif value.lower() in ('true', 'yes', 'y', 'on', '1'):
return True
elif value.lower() in ('false', 'no', 'n', 'off', '0', 'none'):
return False
elif fail_on_errors:
raise ValueError('Cannot parse bool value: %r' % value)
def code_escape(text: str) -> str:
"""HTML escape a string of code."""
if "&" in text:
text = text.replace("&", "&amp;")
if "<" in text:
text = text.replace("<", "&lt;")
if ">" in text:
text = text.replace(">", "&gt;")
return text
def _get_stack_depth(size=2):
"""Get current stack depth, performantly.
"""
frame = sys._getframe(size)
for size in count(size):
frame = frame.f_back
if not frame:
return size
def nearing_recursion_limit() -> bool:
"""Return true if current stack depth is within 100 of maximum limit."""
return sys.getrecursionlimit() - _get_stack_depth() < 100
# MISC AUXILIARY CLASSES
# =============================================================================
class AtomicString(str):
"""A string which should not be further processed."""
pass
class Processor:
""" The base class for all processors.
Attributes:
Processor.md: The `Markdown` instance passed in an initialization.
Arguments:
md: The `Markdown` instance this processor is a part of.
"""
def __init__(self, md: Markdown | None = None):
self.md = md
class HtmlStash:
"""
This class is used for stashing HTML objects that we extract
in the beginning and replace with place-holders.
"""
def __init__(self):
""" Create an `HtmlStash`. """
self.html_counter = 0 # for counting inline html segments
self.rawHtmlBlocks = []
self.tag_counter = 0
self.tag_data = [] # list of dictionaries in the order tags appear
def store(self, html: str) -> str:
"""
Saves an HTML segment for later reinsertion. Returns a
placeholder string that needs to be inserted into the
document.
Keyword arguments:
html: An html segment.
Returns:
A placeholder string.
"""
self.rawHtmlBlocks.append(html)
placeholder = self.get_placeholder(self.html_counter)
self.html_counter += 1
return placeholder
def reset(self) -> None:
""" Clear the stash. """
self.html_counter = 0
self.rawHtmlBlocks = []
def get_placeholder(self, key: int) -> str:
return HTML_PLACEHOLDER % key
def store_tag(self, tag: str, attrs: list, left_index: int, right_index: int) -> str:
"""Store tag data and return a placeholder."""
self.tag_data.append({'tag': tag, 'attrs': attrs,
'left_index': left_index,
'right_index': right_index})
placeholder = TAG_PLACEHOLDER % str(self.tag_counter)
self.tag_counter += 1 # equal to the tag's index in `self.tag_data`
return placeholder
# Used internally by `Registry` for each item in its sorted list.
# Provides an easier to read API when editing the code later.
# For example, `item.name` is more clear than `item[0]`.
class _PriorityItem(NamedTuple):
name: str
priority: float
class Registry(Generic[_T]):
"""
A priority sorted registry.
A `Registry` instance provides two public methods to alter the data of the
registry: `register` and `deregister`. Use `register` to add items and
`deregister` to remove items. See each method for specifics.
When registering an item, a "name" and a "priority" must be provided. All
items are automatically sorted by "priority" from highest to lowest. The
"name" is used to remove ("deregister") and get items.
A `Registry` instance it like a list (which maintains order) when reading
data. You may iterate over the items, get an item and get a count (length)
of all items. You may also check that the registry contains an item.
When getting an item you may use either the index of the item or the
string-based "name". For example:
registry = Registry()
registry.register(SomeItem(), 'itemname', 20)
# Get the item by index
item = registry[0]
# Get the item by name
item = registry['itemname']
When checking that the registry contains an item, you may use either the
string-based "name", or a reference to the actual item. For example:
someitem = SomeItem()
registry.register(someitem, 'itemname', 20)
# Contains the name
assert 'itemname' in registry
# Contains the item instance
assert someitem in registry
The method `get_index_for_name` is also available to obtain the index of
an item using that item's assigned "name".
"""
def __init__(self):
self._data: dict[str, _T] = {}
self._priority = []
self._is_sorted = False
def __contains__(self, item: str | _T) -> bool:
if isinstance(item, str):
# Check if an item exists by this name.
return item in self._data.keys()
# Check if this instance exists.
return item in self._data.values()
def __iter__(self) -> Iterator[_T]:
self._sort()
return iter([self._data[k] for k, p in self._priority])
@overload
def __getitem__(self, key: str | int) -> _T: # pragma: no cover
...
@overload
def __getitem__(self, key: slice) -> Registry[_T]: # pragma: no cover
...
def __getitem__(self, key: str | int | slice) -> _T | Registry[_T]:
self._sort()
if isinstance(key, slice):
data: Registry[_T] = Registry()
for k, p in self._priority[key]:
data.register(self._data[k], k, p)
return data
if isinstance(key, int):
return self._data[self._priority[key].name]
return self._data[key]
def __len__(self) -> int:
return len(self._priority)
def __repr__(self):
return '<{}({})>'.format(self.__class__.__name__, list(self))
def get_index_for_name(self, name: str) -> int:
"""
Return the index of the given name.
"""
if name in self:
self._sort()
return self._priority.index(
[x for x in self._priority if x.name == name][0]
)
raise ValueError('No item named "{}" exists.'.format(name))
def register(self, item: _T, name: str, priority: float) -> None:
"""
Add an item to the registry with the given name and priority.
Arguments:
item: The item being registered.
name: A string used to reference the item.
priority: An integer or float used to sort against all items.
If an item is registered with a "name" which already exists, the
existing item is replaced with the new item. Treat carefully as the
old item is lost with no way to recover it. The new item will be
sorted according to its priority and will **not** retain the position
of the old item.
"""
if name in self:
# Remove existing item of same name first
self.deregister(name)
self._is_sorted = False
self._data[name] = item
self._priority.append(_PriorityItem(name, priority))
def deregister(self, name: str, strict: bool = True) -> None:
"""
Remove an item from the registry.
Set `strict=False` to fail silently. Otherwise a [`ValueError`][] is raised for an unknown `name`.
"""
try:
index = self.get_index_for_name(name)
del self._priority[index]
del self._data[name]
except ValueError:
if strict:
raise
def _sort(self):
"""
Sort the registry by priority from highest to lowest.
This method is called internally and should never be explicitly called.
"""
if not self._is_sorted:
self._priority.sort(key=lambda item: item.priority, reverse=True)
self._is_sorted = True

View File

@@ -0,0 +1,102 @@
# Python imports
# Lib imports
import gi
gi.require_version('Gtk', '3.0')
gi.require_version('Gdk', '3.0')
from gi.repository import Gtk
from gi.repository import Gdk
# Application imports
from core.widgets.webkit.webkit_ui import WebkitUI
from .mixins.markdown_preview_mixin import MarkdownPreviewMixin
class MarkdownPreview(Gtk.Popover, MarkdownPreviewMixin):
def __init__(self):
super(MarkdownPreview, self).__init__()
self.can_hide: bool = True
self.fpath: str = ""
self.is_preview_paused: bool = True # True by default b/c started hidden
self._setup_styling()
self._setup_signals()
self._load_widgets()
def _setup_styling(self):
ctx = self.get_style_context()
ctx.add_class("markdown-preview")
self.set_modal(False)
self.set_can_focus(False)
self.set_transitions_enabled(False)
self.set_size_request(480, 720)
self.override_background_color(
Gtk.StateFlags.NORMAL,
Gdk.RGBA(0, 0, 0, 0.0)
)
self.set_constrain_to(
Gtk.PopoverConstraint.WINDOW
)
def _setup_signals(self):
self.connect("hide", self._handle_hide)
self.connect("show", self._handle_show)
def _load_widgets(self):
box = Gtk.Box()
bttn_box = Gtk.ButtonBox()
scrolled_win = Gtk.ScrolledWindow()
viewport = Gtk.Viewport()
self._markdown_view = WebkitUI()
self.start_stop_bttn = Gtk.ToggleButton()
settings_bttn = Gtk.Button()
self.start_stop_bttn.set_label("gtk-media-pause")
self.start_stop_bttn.set_use_stock(True)
settings_bttn.set_image(
Gtk.Image.new_from_stock(
"gtk-justify-fill", Gtk.IconSize.BUTTON
)
)
self._markdown_view.set_vexpand(True)
box.set_orientation(Gtk.Orientation.VERTICAL)
self.start_stop_bttn.connect("clicked", self._tggle_preview_updates)
settings_bttn.connect("clicked", self._handle_settings)
bttn_box.pack_end(self.start_stop_bttn, expand = False, fill = False, padding = 1)
bttn_box.pack_end(settings_bttn, expand = False, fill = False, padding = 1)
viewport.add(self._markdown_view)
scrolled_win.add(viewport)
box.add(bttn_box)
box.add(scrolled_win)
self.add(box)
box.show_all()
def _handle_hide(self, widget):
if self.can_hide:
self.is_preview_paused = True
return False
return True
def _handle_show(self, widget):
self.can_hide = False
self.is_preview_paused = self.start_stop_bttn.get_active()
def _tggle_preview_updates(self, widget):
self.is_preview_paused = not self.is_preview_paused
def _handle_settings(self, widget):
...

View File

@@ -0,0 +1,3 @@
"""
Pligin Module Mixins
"""

View File

@@ -0,0 +1,39 @@
# Python imports
from pathlib import Path
# Lib imports
# Application imports
from .markdown_template_mixin import MarkdownTemplateMixin
from .. import markdown
class MarkdownPreviewMixin(MarkdownTemplateMixin):
def _do_markdown_translate(self, buffer):
if self.is_preview_paused: return
if not self.is_markdown(buffer):
data = self.wrap_html_to_body("<h1>Not a Markdown file...</h1>")
self._load_html(data)
return
data = self.get_rendered_markdown(buffer)
self._load_html(data, f"file://{self.fpath}")
def _load_html(self, data: str, base_path: str = ""):
self._markdown_view.load_html(
content = data, base_uri = base_path
)
def get_rendered_markdown(self, buffer) -> str:
start_itr = buffer.get_start_iter()
end_itr = buffer.get_end_iter()
text = buffer.get_text(start_itr, end_itr, include_hidden_chars = False)
html = markdown.markdown(text)
return self.wrap_html_to_body(html)
def is_markdown(self, buffer) -> bool:
return buffer.get_language() and buffer.get_language().get_id() == "markdown"

View File

@@ -0,0 +1,43 @@
# Python imports
# Lib imports
# Application imports
class MarkdownTemplateMixin:
def wrap_html_to_body(self, html: str):
return f"""\
<!DOCTYPE html>
<html lang="en" dir="ltr">
<head>
<meta charset="utf-8">
<title>Markdown View</title>
<style media="screen">
html, body {{
display: block;
// background-color: #32383e64;
background-color: #32383e;
color: #ffffff;
text-wrap: wrap;
}}
img {{
width: 100%;
height: auto;
}}
code {{
border: 1px solid #32383e;
background-color: #32383e;
padding: 4px;
}}
</style>
</head>
<body>
{html}
</body>
</html>
"""

View File

@@ -0,0 +1,70 @@
# Python imports
# Lib imports
# Application imports
from libs.event_factory import Event_Factory, Code_Event_Types
from plugins.plugin_types import PluginCode
from .markdown_preview import MarkdownPreview
markdown_preview = MarkdownPreview()
class Plugin(PluginCode):
def __init__(self):
super(Plugin, self).__init__()
def _set_file(self, buffer):
event = Event_Factory.create_event(
"get_file", buffer = buffer
)
self.emit_to("files", event)
if not event.response or not event.response.get_location(): return
markdown_preview.fpath = event.response.get_location().get_path()
def _controller_message(self, event: Code_Event_Types.CodeEvent):
if isinstance(event, Code_Event_Types.FocusedViewEvent):
self._set_file(event.view.get_buffer())
markdown_preview._do_markdown_translate(event.view.get_buffer())
elif isinstance(event, Code_Event_Types.TextChangedEvent):
self._set_file(event.buffer)
markdown_preview._do_markdown_translate(event.buffer)
def load(self):
separator_right = self.request_ui_element("separator-right")
markdown_preview.set_relative_to(separator_right)
event = Event_Factory.create_event("register_command",
command_name = "tggle_markdown_preview",
command = Handler,
binding_mode = "released",
binding = "<Shift><Control>m"
)
self.emit_to("source_views", event)
def run(self):
...
class Handler:
@staticmethod
def execute(
view: any,
*args,
**kwargs
):
logger.debug("Command: Markdown Preview")
if not markdown_preview.can_hide:
markdown_preview.can_hide = True
markdown_preview.popdown() if markdown_preview.is_visible() else markdown_preview.popup()

View File

@@ -0,0 +1,3 @@
"""
Pligin Module
"""

View File

@@ -0,0 +1,3 @@
"""
Pligin Package
"""

Binary file not shown.

After

Width:  |  Height:  |  Size: 6.1 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 6.5 KiB

Some files were not shown because too many files have changed in this diff Show More