Compare commits

...

4 Commits

Author SHA1 Message Date
d1643091c3 refactor: remove LSPServerEventsMixin and clean up websocket tests
- Delete unused websocket library test files
- Remove LSPServerEventsMixin and inline its methods into response handlers
- Clean up unused imports (ThreadPoolExecutor, ABC, LSP message structs)
2026-03-15 03:35:18 -05:00
a4ef662da7 refactor(lsp-manager): replace handler architecture with response registry and modular providers
* Remove legacy handlers system (BaseHandler, DefaultHandler, JavaHandler, PythonHandler, HandlerRegistry)
* Introduce response_handlers module with ResponseRegistry for LSP response routing
* Replace HandlerRegistry usage in LSPManager with ResponseRegistry
* Convert LSPManagerUI client lifecycle to GObject signals
* Remove GLib.idle_add usage in LSP client event dispatch
* Move completion request logic into LSPServerEventsMixin
* Replace provider.py and provider_response_cache.py with modular provider/ package
* Simplify plugin wiring via response_registry.set_event_hub()

This refactor decouples response handling, simplifies event flow, and prepares
the LSP manager for easier language-specific extensions.
2026-03-15 01:50:23 -05:00
6cb66985aa refactor(lsp): replace controller layer with client module and LSPManager orchestration
* Rename legacy controller subsystem (LSPController, websocket controller, controller events, and base classes)
    into clarified client module for LSP communication
* Structure around LSPManager and LSPManagerClient to handle orchestration and client lifecycle
* Update plugin integration to use LSPManager instead of LSPController
* Simplify architecture by reducing controller indirection
2026-03-12 00:04:08 -05:00
52db0b8a31 refactor(lsp): restructure lsp plugin controller architecture and simplify provider cache
- Replace LSPManager usage with LSPController integration
- Move UI access through lsp_controller.lsp_manager_ui
- Remove legacy ProviderResponseCache client management
- Simplify completion filtering and matcher handling
- Improve typing annotations and modernize union syntax
- Clean up unused imports and dead code
- Fix completion item parsing for insertText/textEdit fallbacks
- Add async-safe scrolling via GLib.idle_add
2026-03-11 23:15:19 -05:00
38 changed files with 818 additions and 2433 deletions

View File

@@ -0,0 +1,3 @@
"""
LSP Clients Module
"""

View File

@@ -1,5 +1,7 @@
# Python imports
import threading
from os import path
import json
# Lib imports
import gi
@@ -8,29 +10,27 @@ from gi.repository import GLib
# Application imports
from libs.dto.code.lsp.lsp_messages import get_message_str
from libs.dto.code.lsp.lsp_message_structs import LSPResponseTypes, ClientRequest, ClientNotification
from .lsp_controller_websocket import LSPControllerWebsocket
from .lsp_client_websocket import LSPClientWebsocket
class LSPController(LSPControllerWebsocket):
class LSPClient(LSPClientWebsocket):
def __init__(self):
super(LSPController, self).__init__()
super(LSPClient, self).__init__()
# https://github.com/microsoft/multilspy/tree/main/src/multilspy/language_servers
# initialize-params-slim.json was created off of jedi_language_server one
# self._init_params = settings_manager.get_lsp_init_data()
self._language: str = ""
self._init_params: dict = {}
self._event_history: dict[str] = {}
self._language: str = ""
self._init_params: dict = {}
self._event_history: dict[int, str] = {}
try:
from os import path
import json
_USER_HOME = path.expanduser('~')
_SCRIPT_PTH = path.dirname( path.realpath(__file__) )
_LSP_INIT_CONFIG = f"{_SCRIPT_PTH}/../configs/initialize-params-slim.json"
with open(_LSP_INIT_CONFIG) as file:
data = file.read().replace("{user.home}", _USER_HOME)
self._init_params = json.loads(data)
@@ -42,7 +42,7 @@ class LSPController(LSPControllerWebsocket):
self.read_lock = threading.Lock()
self.write_lock = threading.Lock()
def set_language(self, language):
def set_language(self, language: str):
self._language = language
def set_socket(self, socket: str):
@@ -51,15 +51,15 @@ class LSPController(LSPControllerWebsocket):
def unset_socket(self):
self._socket = None
def send_notification(self, method: str, params: {} = {}):
def send_notification(self, method: str, params: dict = {}):
self._send_message( ClientNotification(method, params) )
def send_request(self, method: str, params: {} = {}):
def send_request(self, method: str, params: dict = {}):
self._message_id += 1
self._event_history[self._message_id] = method
self._send_message( ClientRequest(self._message_id, method, params) )
def get_event_by_id(self, message_id: int):
def get_event_by_id(self, message_id: int) -> str:
if not message_id in self._event_history: return
return self._event_history[message_id]

View File

@@ -3,12 +3,12 @@
# Lib imports
# Application imports
from .lsp_controller_events import LSPControllerEvents
from .lsp_client_events import LSPClientEvents
from libs.dto.code.lsp.lsp_message_structs import ClientRequest, ClientNotification
class LSPControllerBase(LSPControllerEvents):
class LSPClientBase(LSPClientEvents):
def _send_message(self, data: ClientRequest or ClientNotification):
raise NotImplementedError

View File

@@ -2,7 +2,6 @@
import os
# Lib imports
from gi.repository import GLib
# Application imports
from libs.dto.code.lsp.lsp_messages import get_message_obj
@@ -17,7 +16,7 @@ from libs.dto.code.lsp.lsp_messages import symbols_request
class LSPControllerEvents:
class LSPClientEvents:
def send_initialize_message(self, init_ops: dict, workspace_file: str, workspace_uri: str):
folder_name = os.path.basename(workspace_file)
@@ -45,7 +44,7 @@ class LSPControllerEvents:
params["textDocument"]["languageId"] = data["language_id"]
params["textDocument"]["text"] = data["text"]
GLib.idle_add( self.send_notification, method, params )
self.send_notification( method, params )
def _lsp_did_save(self, data: dict):
method = "textDocument/didSave"
@@ -54,7 +53,7 @@ class LSPControllerEvents:
params["textDocument"]["uri"] = data["uri"]
params["text"] = data["text"]
GLib.idle_add( self.send_notification, method, params )
self.send_notification( method, params )
def _lsp_did_close(self, data: dict):
method = "textDocument/didClose"
@@ -62,7 +61,7 @@ class LSPControllerEvents:
params["textDocument"]["uri"] = data["uri"]
GLib.idle_add( self.send_notification, method, params )
self.send_notification( method, params )
def _lsp_did_change(self, data: dict):
method = "textDocument/didChange"
@@ -75,7 +74,7 @@ class LSPControllerEvents:
contentChanges = params["contentChanges"][0]
contentChanges["text"] = data["text"]
GLib.idle_add( self.send_notification, method, params )
self.send_notification( method, params )
# def _lsp_did_change(self, data: dict):
# method = "textDocument/didChange"
@@ -94,7 +93,7 @@ class LSPControllerEvents:
# end["line"] = data["line"]
# end["character"] = data["column"]
# GLib.idle_add( self.send_notification, method, params )
# self.send_notification( method, params )
def _lsp_definition(self, data: dict):
method = "textDocument/definition"
@@ -106,7 +105,7 @@ class LSPControllerEvents:
params["position"]["line"] = data["line"]
params["position"]["character"] = data["column"]
GLib.idle_add( self.send_request, method, params )
self.send_request( method, params )
def _lsp_completion(self, data: dict):
method = "textDocument/completion"
@@ -118,7 +117,7 @@ class LSPControllerEvents:
params["position"]["line"] = data["line"]
params["position"]["character"] = data["column"]
GLib.idle_add( self.send_request, method, params )
self.send_request( method, params )
def _lsp_java_class_file_contents(self, uri: str):
method = "java/classFileContents"
@@ -126,4 +125,4 @@ class LSPControllerEvents:
"uri": uri
}
GLib.idle_add( self.send_request, method, params )
self.send_request( method, params )

View File

@@ -1,23 +1,22 @@
# Python imports
import traceback
import subprocess
# Lib imports
from gi.repository import GLib
# Application imports
# from libs import websockets
from libs.dto.code.lsp.lsp_messages import LEN_HEADER, TYPE_HEADER, get_message_str, get_message_obj
from libs.dto.code.lsp.lsp_message_structs import \
LSPResponseTypes, ClientRequest, ClientNotification, LSPResponseRequest, LSPResponseNotification, LSPIDResponseNotification
from libs.dto.code.lsp.lsp_messages import get_message_str, get_message_obj
from libs.dto.code.lsp.lsp_message_structs import \
LSPResponseTypes, ClientRequest, ClientNotification, \
LSPResponseRequest, LSPResponseNotification, LSPIDResponseNotification
from .lsp_controller_base import LSPControllerBase
from .lsp_client_base import LSPClientBase
from .websocket_client import WebsocketClient
class LSPControllerWebsocket(LSPControllerBase):
def _send_message(self, data: ClientRequest or ClientNotification):
class LSPClientWebsocket(LSPClientBase):
def _send_message(self, data: ClientRequest | ClientNotification):
if not data: return
message_str = get_message_str(data)
@@ -39,7 +38,7 @@ class LSPControllerWebsocket(LSPControllerBase):
if not hasattr(self, "ws_client"): return
self.ws_client.close_client()
def _monitor_lsp_response(self, data: None or {}):
def _monitor_lsp_response(self, data: dict | None):
if not data: return
message = get_message_obj(data)

View File

@@ -1,3 +0,0 @@
"""
Plugin Controller Module
"""

View File

@@ -1,6 +0,0 @@
HTTP/1.1 101 WebSocket Protocol Handshake
Connection: Upgrade
Upgrade: WebSocket
Sec-WebSocket-Accept: Kxep+hNu9n51529fGidYu7a3wO0=
some_header: something

View File

@@ -1,6 +0,0 @@
HTTP/1.1 101 WebSocket Protocol Handshake
Connection: Upgrade
Upgrade WebSocket
Sec-WebSocket-Accept: Kxep+hNu9n51529fGidYu7a3wO0=
some_header: something

View File

@@ -1,8 +0,0 @@
HTTP/1.1 101 WebSocket Protocol Handshake
Connection: Upgrade, Keep-Alive
Upgrade: WebSocket
Sec-WebSocket-Accept: Kxep+hNu9n51529fGidYu7a3wO0=
Set-Cookie: Token=ABCDE
Set-Cookie: Token=FGHIJ
some_header: something

View File

@@ -1,23 +0,0 @@
#!/usr/bin/env python
# From https://github.com/aaugustin/websockets/blob/main/example/echo.py
import asyncio
import os
import websockets
LOCAL_WS_SERVER_PORT = int(os.environ.get("LOCAL_WS_SERVER_PORT", "8765"))
async def echo(websocket):
async for message in websocket:
await websocket.send(message)
async def main():
async with websockets.serve(echo, "localhost", LOCAL_WS_SERVER_PORT):
await asyncio.Future() # run forever
asyncio.run(main())

View File

@@ -1,125 +0,0 @@
# -*- coding: utf-8 -*-
#
import unittest
from websocket._abnf import ABNF, frame_buffer
from websocket._exceptions import WebSocketProtocolException
"""
test_abnf.py
websocket - WebSocket client library for Python
Copyright 2024 engn33r
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
class ABNFTest(unittest.TestCase):
def test_init(self):
a = ABNF(0, 0, 0, 0, opcode=ABNF.OPCODE_PING)
self.assertEqual(a.fin, 0)
self.assertEqual(a.rsv1, 0)
self.assertEqual(a.rsv2, 0)
self.assertEqual(a.rsv3, 0)
self.assertEqual(a.opcode, 9)
self.assertEqual(a.data, "")
a_bad = ABNF(0, 1, 0, 0, opcode=77)
self.assertEqual(a_bad.rsv1, 1)
self.assertEqual(a_bad.opcode, 77)
def test_validate(self):
a_invalid_ping = ABNF(0, 0, 0, 0, opcode=ABNF.OPCODE_PING)
self.assertRaises(
WebSocketProtocolException,
a_invalid_ping.validate,
skip_utf8_validation=False,
)
a_bad_rsv_value = ABNF(0, 1, 0, 0, opcode=ABNF.OPCODE_TEXT)
self.assertRaises(
WebSocketProtocolException,
a_bad_rsv_value.validate,
skip_utf8_validation=False,
)
a_bad_opcode = ABNF(0, 0, 0, 0, opcode=77)
self.assertRaises(
WebSocketProtocolException,
a_bad_opcode.validate,
skip_utf8_validation=False,
)
a_bad_close_frame = ABNF(0, 0, 0, 0, opcode=ABNF.OPCODE_CLOSE, data=b"\x01")
self.assertRaises(
WebSocketProtocolException,
a_bad_close_frame.validate,
skip_utf8_validation=False,
)
a_bad_close_frame_2 = ABNF(
0, 0, 0, 0, opcode=ABNF.OPCODE_CLOSE, data=b"\x01\x8a\xaa\xff\xdd"
)
self.assertRaises(
WebSocketProtocolException,
a_bad_close_frame_2.validate,
skip_utf8_validation=False,
)
a_bad_close_frame_3 = ABNF(
0, 0, 0, 0, opcode=ABNF.OPCODE_CLOSE, data=b"\x03\xe7"
)
self.assertRaises(
WebSocketProtocolException,
a_bad_close_frame_3.validate,
skip_utf8_validation=True,
)
def test_mask(self):
abnf_none_data = ABNF(
0, 0, 0, 0, opcode=ABNF.OPCODE_PING, mask_value=1, data=None
)
bytes_val = b"aaaa"
self.assertEqual(abnf_none_data._get_masked(bytes_val), bytes_val)
abnf_str_data = ABNF(
0, 0, 0, 0, opcode=ABNF.OPCODE_PING, mask_value=1, data="a"
)
self.assertEqual(abnf_str_data._get_masked(bytes_val), b"aaaa\x00")
def test_format(self):
abnf_bad_rsv_bits = ABNF(2, 0, 0, 0, opcode=ABNF.OPCODE_TEXT)
self.assertRaises(ValueError, abnf_bad_rsv_bits.format)
abnf_bad_opcode = ABNF(0, 0, 0, 0, opcode=5)
self.assertRaises(ValueError, abnf_bad_opcode.format)
abnf_length_10 = ABNF(0, 0, 0, 0, opcode=ABNF.OPCODE_TEXT, data="abcdefghij")
self.assertEqual(b"\x01", abnf_length_10.format()[0].to_bytes(1, "big"))
self.assertEqual(b"\x8a", abnf_length_10.format()[1].to_bytes(1, "big"))
self.assertEqual("fin=0 opcode=1 data=abcdefghij", abnf_length_10.__str__())
abnf_length_20 = ABNF(
0, 0, 0, 0, opcode=ABNF.OPCODE_BINARY, data="abcdefghijabcdefghij"
)
self.assertEqual(b"\x02", abnf_length_20.format()[0].to_bytes(1, "big"))
self.assertEqual(b"\x94", abnf_length_20.format()[1].to_bytes(1, "big"))
abnf_no_mask = ABNF(
0, 0, 0, 0, opcode=ABNF.OPCODE_TEXT, mask_value=0, data=b"\x01\x8a\xcc"
)
self.assertEqual(b"\x01\x03\x01\x8a\xcc", abnf_no_mask.format())
def test_frame_buffer(self):
fb = frame_buffer(0, True)
self.assertEqual(fb.recv, 0)
self.assertEqual(fb.skip_utf8_validation, True)
fb.clear
self.assertEqual(fb.header, None)
self.assertEqual(fb.length, None)
self.assertEqual(fb.mask_value, None)
self.assertEqual(fb.has_mask(), False)
if __name__ == "__main__":
unittest.main()

View File

@@ -1,352 +0,0 @@
# -*- coding: utf-8 -*-
#
import os
import os.path
import ssl
import threading
import unittest
import websocket as ws
"""
test_app.py
websocket - WebSocket client library for Python
Copyright 2024 engn33r
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
# Skip test to access the internet unless TEST_WITH_INTERNET == 1
TEST_WITH_INTERNET = os.environ.get("TEST_WITH_INTERNET", "0") == "1"
# Skip tests relying on local websockets server unless LOCAL_WS_SERVER_PORT != -1
LOCAL_WS_SERVER_PORT = os.environ.get("LOCAL_WS_SERVER_PORT", "-1")
TEST_WITH_LOCAL_SERVER = LOCAL_WS_SERVER_PORT != "-1"
TRACEABLE = True
class WebSocketAppTest(unittest.TestCase):
class NotSetYet:
"""A marker class for signalling that a value hasn't been set yet."""
def setUp(self):
ws.enableTrace(TRACEABLE)
WebSocketAppTest.keep_running_open = WebSocketAppTest.NotSetYet()
WebSocketAppTest.keep_running_close = WebSocketAppTest.NotSetYet()
WebSocketAppTest.get_mask_key_id = WebSocketAppTest.NotSetYet()
WebSocketAppTest.on_error_data = WebSocketAppTest.NotSetYet()
def tearDown(self):
WebSocketAppTest.keep_running_open = WebSocketAppTest.NotSetYet()
WebSocketAppTest.keep_running_close = WebSocketAppTest.NotSetYet()
WebSocketAppTest.get_mask_key_id = WebSocketAppTest.NotSetYet()
WebSocketAppTest.on_error_data = WebSocketAppTest.NotSetYet()
def close(self):
pass
@unittest.skipUnless(
TEST_WITH_LOCAL_SERVER, "Tests using local websocket server are disabled"
)
def test_keep_running(self):
"""A WebSocketApp should keep running as long as its self.keep_running
is not False (in the boolean context).
"""
def on_open(self, *args, **kwargs):
"""Set the keep_running flag for later inspection and immediately
close the connection.
"""
self.send("hello!")
WebSocketAppTest.keep_running_open = self.keep_running
self.keep_running = False
def on_message(_, message):
print(message)
self.close()
def on_close(self, *args, **kwargs):
"""Set the keep_running flag for the test to use."""
WebSocketAppTest.keep_running_close = self.keep_running
app = ws.WebSocketApp(
f"ws://127.0.0.1:{LOCAL_WS_SERVER_PORT}",
on_open=on_open,
on_close=on_close,
on_message=on_message,
)
app.run_forever()
# @unittest.skipUnless(TEST_WITH_LOCAL_SERVER, "Tests using local websocket server are disabled")
@unittest.skipUnless(False, "Test disabled for now (requires rel)")
def test_run_forever_dispatcher(self):
"""A WebSocketApp should keep running as long as its self.keep_running
is not False (in the boolean context).
"""
def on_open(self, *args, **kwargs):
"""Send a message, receive, and send one more"""
self.send("hello!")
self.recv()
self.send("goodbye!")
def on_message(_, message):
print(message)
self.close()
app = ws.WebSocketApp(
f"ws://127.0.0.1:{LOCAL_WS_SERVER_PORT}",
on_open=on_open,
on_message=on_message,
)
app.run_forever(dispatcher="Dispatcher") # doesn't work
# app.run_forever(dispatcher=rel) # would work
# rel.dispatch()
@unittest.skipUnless(
TEST_WITH_LOCAL_SERVER, "Tests using local websocket server are disabled"
)
def test_run_forever_teardown_clean_exit(self):
"""The WebSocketApp.run_forever() method should return `False` when the application ends gracefully."""
app = ws.WebSocketApp(f"ws://127.0.0.1:{LOCAL_WS_SERVER_PORT}")
threading.Timer(interval=0.2, function=app.close).start()
teardown = app.run_forever()
self.assertEqual(teardown, False)
@unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
def test_sock_mask_key(self):
"""A WebSocketApp should forward the received mask_key function down
to the actual socket.
"""
def my_mask_key_func():
return "\x00\x00\x00\x00"
app = ws.WebSocketApp(
"wss://api-pub.bitfinex.com/ws/1", get_mask_key=my_mask_key_func
)
# if numpy is installed, this assertion fail
# Note: We can't use 'is' for comparing the functions directly, need to use 'id'.
self.assertEqual(id(app.get_mask_key), id(my_mask_key_func))
@unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
def test_invalid_ping_interval_ping_timeout(self):
"""Test exception handling if ping_interval < ping_timeout"""
def on_ping(app, _):
print("Got a ping!")
app.close()
def on_pong(app, _):
print("Got a pong! No need to respond")
app.close()
app = ws.WebSocketApp(
"wss://api-pub.bitfinex.com/ws/1", on_ping=on_ping, on_pong=on_pong
)
self.assertRaises(
ws.WebSocketException,
app.run_forever,
ping_interval=1,
ping_timeout=2,
sslopt={"cert_reqs": ssl.CERT_NONE},
)
@unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
def test_ping_interval(self):
"""Test WebSocketApp proper ping functionality"""
def on_ping(app, _):
print("Got a ping!")
app.close()
def on_pong(app, _):
print("Got a pong! No need to respond")
app.close()
app = ws.WebSocketApp(
"wss://api-pub.bitfinex.com/ws/1", on_ping=on_ping, on_pong=on_pong
)
app.run_forever(
ping_interval=2, ping_timeout=1, sslopt={"cert_reqs": ssl.CERT_NONE}
)
@unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
def test_opcode_close(self):
"""Test WebSocketApp close opcode"""
app = ws.WebSocketApp("wss://tsock.us1.twilio.com/v3/wsconnect")
app.run_forever(ping_interval=2, ping_timeout=1, ping_payload="Ping payload")
# This is commented out because the URL no longer responds in the expected way
# @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
# def testOpcodeBinary(self):
# """ Test WebSocketApp binary opcode
# """
# app = ws.WebSocketApp('wss://streaming.vn.teslamotors.com/streaming/')
# app.run_forever(ping_interval=2, ping_timeout=1, ping_payload="Ping payload")
@unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
def test_bad_ping_interval(self):
"""A WebSocketApp handling of negative ping_interval"""
app = ws.WebSocketApp("wss://api-pub.bitfinex.com/ws/1")
self.assertRaises(
ws.WebSocketException,
app.run_forever,
ping_interval=-5,
sslopt={"cert_reqs": ssl.CERT_NONE},
)
@unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
def test_bad_ping_timeout(self):
"""A WebSocketApp handling of negative ping_timeout"""
app = ws.WebSocketApp("wss://api-pub.bitfinex.com/ws/1")
self.assertRaises(
ws.WebSocketException,
app.run_forever,
ping_timeout=-3,
sslopt={"cert_reqs": ssl.CERT_NONE},
)
@unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
def test_close_status_code(self):
"""Test extraction of close frame status code and close reason in WebSocketApp"""
def on_close(wsapp, close_status_code, close_msg):
print("on_close reached")
app = ws.WebSocketApp(
"wss://tsock.us1.twilio.com/v3/wsconnect", on_close=on_close
)
closeframe = ws.ABNF(
opcode=ws.ABNF.OPCODE_CLOSE, data=b"\x03\xe8no-init-from-client"
)
self.assertEqual([1000, "no-init-from-client"], app._get_close_args(closeframe))
closeframe = ws.ABNF(opcode=ws.ABNF.OPCODE_CLOSE, data=b"")
self.assertEqual([None, None], app._get_close_args(closeframe))
app2 = ws.WebSocketApp("wss://tsock.us1.twilio.com/v3/wsconnect")
closeframe = ws.ABNF(opcode=ws.ABNF.OPCODE_CLOSE, data=b"")
self.assertEqual([None, None], app2._get_close_args(closeframe))
self.assertRaises(
ws.WebSocketConnectionClosedException,
app.send,
data="test if connection is closed",
)
@unittest.skipUnless(
TEST_WITH_LOCAL_SERVER, "Tests using local websocket server are disabled"
)
def test_callback_function_exception(self):
"""Test callback function exception handling"""
exc = None
passed_app = None
def on_open(app):
raise RuntimeError("Callback failed")
def on_error(app, err):
nonlocal passed_app
passed_app = app
nonlocal exc
exc = err
def on_pong(app, _):
app.close()
app = ws.WebSocketApp(
f"ws://127.0.0.1:{LOCAL_WS_SERVER_PORT}",
on_open=on_open,
on_error=on_error,
on_pong=on_pong,
)
app.run_forever(ping_interval=2, ping_timeout=1)
self.assertEqual(passed_app, app)
self.assertIsInstance(exc, RuntimeError)
self.assertEqual(str(exc), "Callback failed")
@unittest.skipUnless(
TEST_WITH_LOCAL_SERVER, "Tests using local websocket server are disabled"
)
def test_callback_method_exception(self):
"""Test callback method exception handling"""
class Callbacks:
def __init__(self):
self.exc = None
self.passed_app = None
self.app = ws.WebSocketApp(
f"ws://127.0.0.1:{LOCAL_WS_SERVER_PORT}",
on_open=self.on_open,
on_error=self.on_error,
on_pong=self.on_pong,
)
self.app.run_forever(ping_interval=2, ping_timeout=1)
def on_open(self, _):
raise RuntimeError("Callback failed")
def on_error(self, app, err):
self.passed_app = app
self.exc = err
def on_pong(self, app, _):
app.close()
callbacks = Callbacks()
self.assertEqual(callbacks.passed_app, callbacks.app)
self.assertIsInstance(callbacks.exc, RuntimeError)
self.assertEqual(str(callbacks.exc), "Callback failed")
@unittest.skipUnless(
TEST_WITH_LOCAL_SERVER, "Tests using local websocket server are disabled"
)
def test_reconnect(self):
"""Test reconnect"""
pong_count = 0
exc = None
def on_error(_, err):
nonlocal exc
exc = err
def on_pong(app, _):
nonlocal pong_count
pong_count += 1
if pong_count == 1:
# First pong, shutdown socket, enforce read error
app.sock.shutdown()
if pong_count >= 2:
# Got second pong after reconnect
app.close()
app = ws.WebSocketApp(
f"ws://127.0.0.1:{LOCAL_WS_SERVER_PORT}", on_pong=on_pong, on_error=on_error
)
app.run_forever(ping_interval=2, ping_timeout=1, reconnect=3)
self.assertEqual(pong_count, 2)
self.assertIsInstance(exc, ws.WebSocketTimeoutException)
self.assertEqual(str(exc), "ping/pong timed out")
if __name__ == "__main__":
unittest.main()

View File

@@ -1,123 +0,0 @@
import unittest
from websocket._cookiejar import SimpleCookieJar
"""
test_cookiejar.py
websocket - WebSocket client library for Python
Copyright 2024 engn33r
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
class CookieJarTest(unittest.TestCase):
def test_add(self):
cookie_jar = SimpleCookieJar()
cookie_jar.add("")
self.assertFalse(
cookie_jar.jar, "Cookie with no domain should not be added to the jar"
)
cookie_jar = SimpleCookieJar()
cookie_jar.add("a=b")
self.assertFalse(
cookie_jar.jar, "Cookie with no domain should not be added to the jar"
)
cookie_jar = SimpleCookieJar()
cookie_jar.add("a=b; domain=.abc")
self.assertTrue(".abc" in cookie_jar.jar)
cookie_jar = SimpleCookieJar()
cookie_jar.add("a=b; domain=abc")
self.assertTrue(".abc" in cookie_jar.jar)
self.assertTrue("abc" not in cookie_jar.jar)
cookie_jar = SimpleCookieJar()
cookie_jar.add("a=b; c=d; domain=abc")
self.assertEqual(cookie_jar.get("abc"), "a=b; c=d")
self.assertEqual(cookie_jar.get(None), "")
cookie_jar = SimpleCookieJar()
cookie_jar.add("a=b; c=d; domain=abc")
cookie_jar.add("e=f; domain=abc")
self.assertEqual(cookie_jar.get("abc"), "a=b; c=d; e=f")
cookie_jar = SimpleCookieJar()
cookie_jar.add("a=b; c=d; domain=abc")
cookie_jar.add("e=f; domain=.abc")
self.assertEqual(cookie_jar.get("abc"), "a=b; c=d; e=f")
cookie_jar = SimpleCookieJar()
cookie_jar.add("a=b; c=d; domain=abc")
cookie_jar.add("e=f; domain=xyz")
self.assertEqual(cookie_jar.get("abc"), "a=b; c=d")
self.assertEqual(cookie_jar.get("xyz"), "e=f")
self.assertEqual(cookie_jar.get("something"), "")
def test_set(self):
cookie_jar = SimpleCookieJar()
cookie_jar.set("a=b")
self.assertFalse(
cookie_jar.jar, "Cookie with no domain should not be added to the jar"
)
cookie_jar = SimpleCookieJar()
cookie_jar.set("a=b; domain=.abc")
self.assertTrue(".abc" in cookie_jar.jar)
cookie_jar = SimpleCookieJar()
cookie_jar.set("a=b; domain=abc")
self.assertTrue(".abc" in cookie_jar.jar)
self.assertTrue("abc" not in cookie_jar.jar)
cookie_jar = SimpleCookieJar()
cookie_jar.set("a=b; c=d; domain=abc")
self.assertEqual(cookie_jar.get("abc"), "a=b; c=d")
cookie_jar = SimpleCookieJar()
cookie_jar.set("a=b; c=d; domain=abc")
cookie_jar.set("e=f; domain=abc")
self.assertEqual(cookie_jar.get("abc"), "e=f")
cookie_jar = SimpleCookieJar()
cookie_jar.set("a=b; c=d; domain=abc")
cookie_jar.set("e=f; domain=.abc")
self.assertEqual(cookie_jar.get("abc"), "e=f")
cookie_jar = SimpleCookieJar()
cookie_jar.set("a=b; c=d; domain=abc")
cookie_jar.set("e=f; domain=xyz")
self.assertEqual(cookie_jar.get("abc"), "a=b; c=d")
self.assertEqual(cookie_jar.get("xyz"), "e=f")
self.assertEqual(cookie_jar.get("something"), "")
def test_get(self):
cookie_jar = SimpleCookieJar()
cookie_jar.set("a=b; c=d; domain=abc.com")
self.assertEqual(cookie_jar.get("abc.com"), "a=b; c=d")
self.assertEqual(cookie_jar.get("x.abc.com"), "a=b; c=d")
self.assertEqual(cookie_jar.get("abc.com.es"), "")
self.assertEqual(cookie_jar.get("xabc.com"), "")
cookie_jar.set("a=b; c=d; domain=.abc.com")
self.assertEqual(cookie_jar.get("abc.com"), "a=b; c=d")
self.assertEqual(cookie_jar.get("x.abc.com"), "a=b; c=d")
self.assertEqual(cookie_jar.get("abc.com.es"), "")
self.assertEqual(cookie_jar.get("xabc.com"), "")
if __name__ == "__main__":
unittest.main()

View File

@@ -1,370 +0,0 @@
# -*- coding: utf-8 -*-
#
import os
import os.path
import socket
import ssl
import unittest
import websocket
from websocket._exceptions import WebSocketProxyException, WebSocketException
from websocket._http import (
_get_addrinfo_list,
_start_proxied_socket,
_tunnel,
connect,
proxy_info,
read_headers,
HAVE_PYTHON_SOCKS,
)
"""
test_http.py
websocket - WebSocket client library for Python
Copyright 2024 engn33r
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
try:
from python_socks._errors import ProxyConnectionError, ProxyError, ProxyTimeoutError
except:
from websocket._http import ProxyConnectionError, ProxyError, ProxyTimeoutError
# Skip test to access the internet unless TEST_WITH_INTERNET == 1
TEST_WITH_INTERNET = os.environ.get("TEST_WITH_INTERNET", "0") == "1"
TEST_WITH_PROXY = os.environ.get("TEST_WITH_PROXY", "0") == "1"
# Skip tests relying on local websockets server unless LOCAL_WS_SERVER_PORT != -1
LOCAL_WS_SERVER_PORT = os.environ.get("LOCAL_WS_SERVER_PORT", "-1")
TEST_WITH_LOCAL_SERVER = LOCAL_WS_SERVER_PORT != "-1"
class SockMock:
def __init__(self):
self.data = []
self.sent = []
def add_packet(self, data):
self.data.append(data)
def gettimeout(self):
return None
def recv(self, bufsize):
if self.data:
e = self.data.pop(0)
if isinstance(e, Exception):
raise e
if len(e) > bufsize:
self.data.insert(0, e[bufsize:])
return e[:bufsize]
def send(self, data):
self.sent.append(data)
return len(data)
def close(self):
pass
class HeaderSockMock(SockMock):
def __init__(self, fname):
SockMock.__init__(self)
path = os.path.join(os.path.dirname(__file__), fname)
with open(path, "rb") as f:
self.add_packet(f.read())
class OptsList:
def __init__(self):
self.timeout = 1
self.sockopt = []
self.sslopt = {"cert_reqs": ssl.CERT_NONE}
class HttpTest(unittest.TestCase):
def test_read_header(self):
status, header, _ = read_headers(HeaderSockMock("data/header01.txt"))
self.assertEqual(status, 101)
self.assertEqual(header["connection"], "Upgrade")
# header02.txt is intentionally malformed
self.assertRaises(
WebSocketException, read_headers, HeaderSockMock("data/header02.txt")
)
def test_tunnel(self):
self.assertRaises(
WebSocketProxyException,
_tunnel,
HeaderSockMock("data/header01.txt"),
"example.com",
80,
("username", "password"),
)
self.assertRaises(
WebSocketProxyException,
_tunnel,
HeaderSockMock("data/header02.txt"),
"example.com",
80,
("username", "password"),
)
@unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
def test_connect(self):
# Not currently testing an actual proxy connection, so just check whether proxy errors are raised. This requires internet for a DNS lookup
if HAVE_PYTHON_SOCKS:
# Need this check, otherwise case where python_socks is not installed triggers
# websocket._exceptions.WebSocketException: Python Socks is needed for SOCKS proxying but is not available
self.assertRaises(
(ProxyTimeoutError, OSError),
_start_proxied_socket,
"wss://example.com",
OptsList(),
proxy_info(
http_proxy_host="example.com",
http_proxy_port="8080",
proxy_type="socks4",
http_proxy_timeout=1,
),
)
self.assertRaises(
(ProxyTimeoutError, OSError),
_start_proxied_socket,
"wss://example.com",
OptsList(),
proxy_info(
http_proxy_host="example.com",
http_proxy_port="8080",
proxy_type="socks4a",
http_proxy_timeout=1,
),
)
self.assertRaises(
(ProxyTimeoutError, OSError),
_start_proxied_socket,
"wss://example.com",
OptsList(),
proxy_info(
http_proxy_host="example.com",
http_proxy_port="8080",
proxy_type="socks5",
http_proxy_timeout=1,
),
)
self.assertRaises(
(ProxyTimeoutError, OSError),
_start_proxied_socket,
"wss://example.com",
OptsList(),
proxy_info(
http_proxy_host="example.com",
http_proxy_port="8080",
proxy_type="socks5h",
http_proxy_timeout=1,
),
)
self.assertRaises(
ProxyConnectionError,
connect,
"wss://example.com",
OptsList(),
proxy_info(
http_proxy_host="127.0.0.1",
http_proxy_port=9999,
proxy_type="socks4",
http_proxy_timeout=1,
),
None,
)
self.assertRaises(
TypeError,
_get_addrinfo_list,
None,
80,
True,
proxy_info(
http_proxy_host="127.0.0.1", http_proxy_port="9999", proxy_type="http"
),
)
self.assertRaises(
TypeError,
_get_addrinfo_list,
None,
80,
True,
proxy_info(
http_proxy_host="127.0.0.1", http_proxy_port="9999", proxy_type="http"
),
)
self.assertRaises(
socket.timeout,
connect,
"wss://google.com",
OptsList(),
proxy_info(
http_proxy_host="8.8.8.8",
http_proxy_port=9999,
proxy_type="http",
http_proxy_timeout=1,
),
None,
)
self.assertEqual(
connect(
"wss://google.com",
OptsList(),
proxy_info(
http_proxy_host="8.8.8.8", http_proxy_port=8080, proxy_type="http"
),
True,
),
(True, ("google.com", 443, "/")),
)
# The following test fails on Mac OS with a gaierror, not an OverflowError
# self.assertRaises(OverflowError, connect, "wss://example.com", OptsList(), proxy_info(http_proxy_host="127.0.0.1", http_proxy_port=99999, proxy_type="socks4", timeout=2), False)
@unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
@unittest.skipUnless(
TEST_WITH_PROXY, "This test requires a HTTP proxy to be running on port 8899"
)
@unittest.skipUnless(
TEST_WITH_LOCAL_SERVER, "Tests using local websocket server are disabled"
)
def test_proxy_connect(self):
ws = websocket.WebSocket()
ws.connect(
f"ws://127.0.0.1:{LOCAL_WS_SERVER_PORT}",
http_proxy_host="127.0.0.1",
http_proxy_port="8899",
proxy_type="http",
)
ws.send("Hello, Server")
server_response = ws.recv()
self.assertEqual(server_response, "Hello, Server")
# self.assertEqual(_start_proxied_socket("wss://api.bitfinex.com/ws/2", OptsList(), proxy_info(http_proxy_host="127.0.0.1", http_proxy_port="8899", proxy_type="http"))[1], ("api.bitfinex.com", 443, '/ws/2'))
self.assertEqual(
_get_addrinfo_list(
"api.bitfinex.com",
443,
True,
proxy_info(
http_proxy_host="127.0.0.1",
http_proxy_port="8899",
proxy_type="http",
),
),
(
socket.getaddrinfo(
"127.0.0.1", 8899, 0, socket.SOCK_STREAM, socket.SOL_TCP
),
True,
None,
),
)
self.assertEqual(
connect(
"wss://api.bitfinex.com/ws/2",
OptsList(),
proxy_info(
http_proxy_host="127.0.0.1", http_proxy_port=8899, proxy_type="http"
),
None,
)[1],
("api.bitfinex.com", 443, "/ws/2"),
)
# TODO: Test SOCKS4 and SOCK5 proxies with unit tests
@unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
def test_sslopt(self):
ssloptions = {
"check_hostname": False,
"server_hostname": "ServerName",
"ssl_version": ssl.PROTOCOL_TLS_CLIENT,
"ciphers": "TLS_AES_256_GCM_SHA384:TLS_CHACHA20_POLY1305_SHA256:\
TLS_AES_128_GCM_SHA256:ECDHE-ECDSA-AES256-GCM-SHA384:\
ECDHE-RSA-AES256-GCM-SHA384:DHE-RSA-AES256-GCM-SHA384:\
ECDHE-ECDSA-CHACHA20-POLY1305:ECDHE-RSA-CHACHA20-POLY1305:\
DHE-RSA-CHACHA20-POLY1305:ECDHE-ECDSA-AES128-GCM-SHA256:\
ECDHE-RSA-AES128-GCM-SHA256:DHE-RSA-AES128-GCM-SHA256:\
ECDHE-ECDSA-AES256-SHA384:ECDHE-RSA-AES256-SHA384:\
DHE-RSA-AES256-SHA256:ECDHE-ECDSA-AES128-SHA256:\
ECDHE-RSA-AES128-SHA256:DHE-RSA-AES128-SHA256:\
ECDHE-ECDSA-AES256-SHA:ECDHE-RSA-AES256-SHA",
"ecdh_curve": "prime256v1",
}
ws_ssl1 = websocket.WebSocket(sslopt=ssloptions)
ws_ssl1.connect("wss://api.bitfinex.com/ws/2")
ws_ssl1.send("Hello")
ws_ssl1.close()
ws_ssl2 = websocket.WebSocket(sslopt={"check_hostname": True})
ws_ssl2.connect("wss://api.bitfinex.com/ws/2")
ws_ssl2.close
def test_proxy_info(self):
self.assertEqual(
proxy_info(
http_proxy_host="127.0.0.1", http_proxy_port="8080", proxy_type="http"
).proxy_protocol,
"http",
)
self.assertRaises(
ProxyError,
proxy_info,
http_proxy_host="127.0.0.1",
http_proxy_port="8080",
proxy_type="badval",
)
self.assertEqual(
proxy_info(
http_proxy_host="example.com", http_proxy_port="8080", proxy_type="http"
).proxy_host,
"example.com",
)
self.assertEqual(
proxy_info(
http_proxy_host="127.0.0.1", http_proxy_port="8080", proxy_type="http"
).proxy_port,
"8080",
)
self.assertEqual(
proxy_info(
http_proxy_host="127.0.0.1", http_proxy_port="8080", proxy_type="http"
).auth,
None,
)
self.assertEqual(
proxy_info(
http_proxy_host="127.0.0.1",
http_proxy_port="8080",
proxy_type="http",
http_proxy_auth=("my_username123", "my_pass321"),
).auth[0],
"my_username123",
)
self.assertEqual(
proxy_info(
http_proxy_host="127.0.0.1",
http_proxy_port="8080",
proxy_type="http",
http_proxy_auth=("my_username123", "my_pass321"),
).auth[1],
"my_pass321",
)
if __name__ == "__main__":
unittest.main()

View File

@@ -1,464 +0,0 @@
# -*- coding: utf-8 -*-
#
import os
import unittest
from websocket._url import (
_is_address_in_network,
_is_no_proxy_host,
get_proxy_info,
parse_url,
)
from websocket._exceptions import WebSocketProxyException
"""
test_url.py
websocket - WebSocket client library for Python
Copyright 2024 engn33r
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
class UrlTest(unittest.TestCase):
def test_address_in_network(self):
self.assertTrue(_is_address_in_network("127.0.0.1", "127.0.0.0/8"))
self.assertTrue(_is_address_in_network("127.1.0.1", "127.0.0.0/8"))
self.assertFalse(_is_address_in_network("127.1.0.1", "127.0.0.0/24"))
def test_parse_url(self):
p = parse_url("ws://www.example.com/r")
self.assertEqual(p[0], "www.example.com")
self.assertEqual(p[1], 80)
self.assertEqual(p[2], "/r")
self.assertEqual(p[3], False)
p = parse_url("ws://www.example.com/r/")
self.assertEqual(p[0], "www.example.com")
self.assertEqual(p[1], 80)
self.assertEqual(p[2], "/r/")
self.assertEqual(p[3], False)
p = parse_url("ws://www.example.com/")
self.assertEqual(p[0], "www.example.com")
self.assertEqual(p[1], 80)
self.assertEqual(p[2], "/")
self.assertEqual(p[3], False)
p = parse_url("ws://www.example.com")
self.assertEqual(p[0], "www.example.com")
self.assertEqual(p[1], 80)
self.assertEqual(p[2], "/")
self.assertEqual(p[3], False)
p = parse_url("ws://www.example.com:8080/r")
self.assertEqual(p[0], "www.example.com")
self.assertEqual(p[1], 8080)
self.assertEqual(p[2], "/r")
self.assertEqual(p[3], False)
p = parse_url("ws://www.example.com:8080/")
self.assertEqual(p[0], "www.example.com")
self.assertEqual(p[1], 8080)
self.assertEqual(p[2], "/")
self.assertEqual(p[3], False)
p = parse_url("ws://www.example.com:8080")
self.assertEqual(p[0], "www.example.com")
self.assertEqual(p[1], 8080)
self.assertEqual(p[2], "/")
self.assertEqual(p[3], False)
p = parse_url("wss://www.example.com:8080/r")
self.assertEqual(p[0], "www.example.com")
self.assertEqual(p[1], 8080)
self.assertEqual(p[2], "/r")
self.assertEqual(p[3], True)
p = parse_url("wss://www.example.com:8080/r?key=value")
self.assertEqual(p[0], "www.example.com")
self.assertEqual(p[1], 8080)
self.assertEqual(p[2], "/r?key=value")
self.assertEqual(p[3], True)
self.assertRaises(ValueError, parse_url, "http://www.example.com/r")
p = parse_url("ws://[2a03:4000:123:83::3]/r")
self.assertEqual(p[0], "2a03:4000:123:83::3")
self.assertEqual(p[1], 80)
self.assertEqual(p[2], "/r")
self.assertEqual(p[3], False)
p = parse_url("ws://[2a03:4000:123:83::3]:8080/r")
self.assertEqual(p[0], "2a03:4000:123:83::3")
self.assertEqual(p[1], 8080)
self.assertEqual(p[2], "/r")
self.assertEqual(p[3], False)
p = parse_url("wss://[2a03:4000:123:83::3]/r")
self.assertEqual(p[0], "2a03:4000:123:83::3")
self.assertEqual(p[1], 443)
self.assertEqual(p[2], "/r")
self.assertEqual(p[3], True)
p = parse_url("wss://[2a03:4000:123:83::3]:8080/r")
self.assertEqual(p[0], "2a03:4000:123:83::3")
self.assertEqual(p[1], 8080)
self.assertEqual(p[2], "/r")
self.assertEqual(p[3], True)
class IsNoProxyHostTest(unittest.TestCase):
def setUp(self):
self.no_proxy = os.environ.get("no_proxy", None)
if "no_proxy" in os.environ:
del os.environ["no_proxy"]
def tearDown(self):
if self.no_proxy:
os.environ["no_proxy"] = self.no_proxy
elif "no_proxy" in os.environ:
del os.environ["no_proxy"]
def test_match_all(self):
self.assertTrue(_is_no_proxy_host("any.websocket.org", ["*"]))
self.assertTrue(_is_no_proxy_host("192.168.0.1", ["*"]))
self.assertFalse(_is_no_proxy_host("192.168.0.1", ["192.168.1.1"]))
self.assertFalse(
_is_no_proxy_host("any.websocket.org", ["other.websocket.org"])
)
self.assertTrue(
_is_no_proxy_host("any.websocket.org", ["other.websocket.org", "*"])
)
os.environ["no_proxy"] = "*"
self.assertTrue(_is_no_proxy_host("any.websocket.org", None))
self.assertTrue(_is_no_proxy_host("192.168.0.1", None))
os.environ["no_proxy"] = "other.websocket.org, *"
self.assertTrue(_is_no_proxy_host("any.websocket.org", None))
def test_ip_address(self):
self.assertTrue(_is_no_proxy_host("127.0.0.1", ["127.0.0.1"]))
self.assertFalse(_is_no_proxy_host("127.0.0.2", ["127.0.0.1"]))
self.assertTrue(
_is_no_proxy_host("127.0.0.1", ["other.websocket.org", "127.0.0.1"])
)
self.assertFalse(
_is_no_proxy_host("127.0.0.2", ["other.websocket.org", "127.0.0.1"])
)
os.environ["no_proxy"] = "127.0.0.1"
self.assertTrue(_is_no_proxy_host("127.0.0.1", None))
self.assertFalse(_is_no_proxy_host("127.0.0.2", None))
os.environ["no_proxy"] = "other.websocket.org, 127.0.0.1"
self.assertTrue(_is_no_proxy_host("127.0.0.1", None))
self.assertFalse(_is_no_proxy_host("127.0.0.2", None))
def test_ip_address_in_range(self):
self.assertTrue(_is_no_proxy_host("127.0.0.1", ["127.0.0.0/8"]))
self.assertTrue(_is_no_proxy_host("127.0.0.2", ["127.0.0.0/8"]))
self.assertFalse(_is_no_proxy_host("127.1.0.1", ["127.0.0.0/24"]))
os.environ["no_proxy"] = "127.0.0.0/8"
self.assertTrue(_is_no_proxy_host("127.0.0.1", None))
self.assertTrue(_is_no_proxy_host("127.0.0.2", None))
os.environ["no_proxy"] = "127.0.0.0/24"
self.assertFalse(_is_no_proxy_host("127.1.0.1", None))
def test_hostname_match(self):
self.assertTrue(_is_no_proxy_host("my.websocket.org", ["my.websocket.org"]))
self.assertTrue(
_is_no_proxy_host(
"my.websocket.org", ["other.websocket.org", "my.websocket.org"]
)
)
self.assertFalse(_is_no_proxy_host("my.websocket.org", ["other.websocket.org"]))
os.environ["no_proxy"] = "my.websocket.org"
self.assertTrue(_is_no_proxy_host("my.websocket.org", None))
self.assertFalse(_is_no_proxy_host("other.websocket.org", None))
os.environ["no_proxy"] = "other.websocket.org, my.websocket.org"
self.assertTrue(_is_no_proxy_host("my.websocket.org", None))
def test_hostname_match_domain(self):
self.assertTrue(_is_no_proxy_host("any.websocket.org", [".websocket.org"]))
self.assertTrue(_is_no_proxy_host("my.other.websocket.org", [".websocket.org"]))
self.assertTrue(
_is_no_proxy_host(
"any.websocket.org", ["my.websocket.org", ".websocket.org"]
)
)
self.assertFalse(_is_no_proxy_host("any.websocket.com", [".websocket.org"]))
os.environ["no_proxy"] = ".websocket.org"
self.assertTrue(_is_no_proxy_host("any.websocket.org", None))
self.assertTrue(_is_no_proxy_host("my.other.websocket.org", None))
self.assertFalse(_is_no_proxy_host("any.websocket.com", None))
os.environ["no_proxy"] = "my.websocket.org, .websocket.org"
self.assertTrue(_is_no_proxy_host("any.websocket.org", None))
class ProxyInfoTest(unittest.TestCase):
def setUp(self):
self.http_proxy = os.environ.get("http_proxy", None)
self.https_proxy = os.environ.get("https_proxy", None)
self.no_proxy = os.environ.get("no_proxy", None)
if "http_proxy" in os.environ:
del os.environ["http_proxy"]
if "https_proxy" in os.environ:
del os.environ["https_proxy"]
if "no_proxy" in os.environ:
del os.environ["no_proxy"]
def tearDown(self):
if self.http_proxy:
os.environ["http_proxy"] = self.http_proxy
elif "http_proxy" in os.environ:
del os.environ["http_proxy"]
if self.https_proxy:
os.environ["https_proxy"] = self.https_proxy
elif "https_proxy" in os.environ:
del os.environ["https_proxy"]
if self.no_proxy:
os.environ["no_proxy"] = self.no_proxy
elif "no_proxy" in os.environ:
del os.environ["no_proxy"]
def test_proxy_from_args(self):
self.assertRaises(
WebSocketProxyException,
get_proxy_info,
"echo.websocket.events",
False,
proxy_host="localhost",
)
self.assertEqual(
get_proxy_info(
"echo.websocket.events", False, proxy_host="localhost", proxy_port=3128
),
("localhost", 3128, None),
)
self.assertEqual(
get_proxy_info(
"echo.websocket.events", True, proxy_host="localhost", proxy_port=3128
),
("localhost", 3128, None),
)
self.assertEqual(
get_proxy_info(
"echo.websocket.events",
False,
proxy_host="localhost",
proxy_port=9001,
proxy_auth=("a", "b"),
),
("localhost", 9001, ("a", "b")),
)
self.assertEqual(
get_proxy_info(
"echo.websocket.events",
False,
proxy_host="localhost",
proxy_port=3128,
proxy_auth=("a", "b"),
),
("localhost", 3128, ("a", "b")),
)
self.assertEqual(
get_proxy_info(
"echo.websocket.events",
True,
proxy_host="localhost",
proxy_port=8765,
proxy_auth=("a", "b"),
),
("localhost", 8765, ("a", "b")),
)
self.assertEqual(
get_proxy_info(
"echo.websocket.events",
True,
proxy_host="localhost",
proxy_port=3128,
proxy_auth=("a", "b"),
),
("localhost", 3128, ("a", "b")),
)
self.assertEqual(
get_proxy_info(
"echo.websocket.events",
True,
proxy_host="localhost",
proxy_port=3128,
no_proxy=["example.com"],
proxy_auth=("a", "b"),
),
("localhost", 3128, ("a", "b")),
)
self.assertEqual(
get_proxy_info(
"echo.websocket.events",
True,
proxy_host="localhost",
proxy_port=3128,
no_proxy=["echo.websocket.events"],
proxy_auth=("a", "b"),
),
(None, 0, None),
)
self.assertEqual(
get_proxy_info(
"echo.websocket.events",
True,
proxy_host="localhost",
proxy_port=3128,
no_proxy=[".websocket.events"],
),
(None, 0, None),
)
def test_proxy_from_env(self):
os.environ["http_proxy"] = "http://localhost/"
self.assertEqual(
get_proxy_info("echo.websocket.events", False), ("localhost", None, None)
)
os.environ["http_proxy"] = "http://localhost:3128/"
self.assertEqual(
get_proxy_info("echo.websocket.events", False), ("localhost", 3128, None)
)
os.environ["http_proxy"] = "http://localhost/"
os.environ["https_proxy"] = "http://localhost2/"
self.assertEqual(
get_proxy_info("echo.websocket.events", False), ("localhost", None, None)
)
os.environ["http_proxy"] = "http://localhost:3128/"
os.environ["https_proxy"] = "http://localhost2:3128/"
self.assertEqual(
get_proxy_info("echo.websocket.events", False), ("localhost", 3128, None)
)
os.environ["http_proxy"] = "http://localhost/"
os.environ["https_proxy"] = "http://localhost2/"
self.assertEqual(
get_proxy_info("echo.websocket.events", True), ("localhost2", None, None)
)
os.environ["http_proxy"] = "http://localhost:3128/"
os.environ["https_proxy"] = "http://localhost2:3128/"
self.assertEqual(
get_proxy_info("echo.websocket.events", True), ("localhost2", 3128, None)
)
os.environ["http_proxy"] = ""
os.environ["https_proxy"] = "http://localhost2/"
self.assertEqual(
get_proxy_info("echo.websocket.events", True), ("localhost2", None, None)
)
self.assertEqual(
get_proxy_info("echo.websocket.events", False), (None, 0, None)
)
os.environ["http_proxy"] = ""
os.environ["https_proxy"] = "http://localhost2:3128/"
self.assertEqual(
get_proxy_info("echo.websocket.events", True), ("localhost2", 3128, None)
)
self.assertEqual(
get_proxy_info("echo.websocket.events", False), (None, 0, None)
)
os.environ["http_proxy"] = "http://localhost/"
os.environ["https_proxy"] = ""
self.assertEqual(get_proxy_info("echo.websocket.events", True), (None, 0, None))
self.assertEqual(
get_proxy_info("echo.websocket.events", False), ("localhost", None, None)
)
os.environ["http_proxy"] = "http://localhost:3128/"
os.environ["https_proxy"] = ""
self.assertEqual(get_proxy_info("echo.websocket.events", True), (None, 0, None))
self.assertEqual(
get_proxy_info("echo.websocket.events", False), ("localhost", 3128, None)
)
os.environ["http_proxy"] = "http://a:b@localhost/"
self.assertEqual(
get_proxy_info("echo.websocket.events", False),
("localhost", None, ("a", "b")),
)
os.environ["http_proxy"] = "http://a:b@localhost:3128/"
self.assertEqual(
get_proxy_info("echo.websocket.events", False),
("localhost", 3128, ("a", "b")),
)
os.environ["http_proxy"] = "http://a:b@localhost/"
os.environ["https_proxy"] = "http://a:b@localhost2/"
self.assertEqual(
get_proxy_info("echo.websocket.events", False),
("localhost", None, ("a", "b")),
)
os.environ["http_proxy"] = "http://a:b@localhost:3128/"
os.environ["https_proxy"] = "http://a:b@localhost2:3128/"
self.assertEqual(
get_proxy_info("echo.websocket.events", False),
("localhost", 3128, ("a", "b")),
)
os.environ["http_proxy"] = "http://a:b@localhost/"
os.environ["https_proxy"] = "http://a:b@localhost2/"
self.assertEqual(
get_proxy_info("echo.websocket.events", True),
("localhost2", None, ("a", "b")),
)
os.environ["http_proxy"] = "http://a:b@localhost:3128/"
os.environ["https_proxy"] = "http://a:b@localhost2:3128/"
self.assertEqual(
get_proxy_info("echo.websocket.events", True),
("localhost2", 3128, ("a", "b")),
)
os.environ["http_proxy"] = (
"http://john%40example.com:P%40SSWORD@localhost:3128/"
)
os.environ["https_proxy"] = (
"http://john%40example.com:P%40SSWORD@localhost2:3128/"
)
self.assertEqual(
get_proxy_info("echo.websocket.events", True),
("localhost2", 3128, ("john@example.com", "P@SSWORD")),
)
os.environ["http_proxy"] = "http://a:b@localhost/"
os.environ["https_proxy"] = "http://a:b@localhost2/"
os.environ["no_proxy"] = "example1.com,example2.com"
self.assertEqual(
get_proxy_info("example.1.com", True), ("localhost2", None, ("a", "b"))
)
os.environ["http_proxy"] = "http://a:b@localhost:3128/"
os.environ["https_proxy"] = "http://a:b@localhost2:3128/"
os.environ["no_proxy"] = "example1.com,example2.com, echo.websocket.events"
self.assertEqual(get_proxy_info("echo.websocket.events", True), (None, 0, None))
os.environ["http_proxy"] = "http://a:b@localhost:3128/"
os.environ["https_proxy"] = "http://a:b@localhost2:3128/"
os.environ["no_proxy"] = "example1.com,example2.com, .websocket.events"
self.assertEqual(get_proxy_info("echo.websocket.events", True), (None, 0, None))
os.environ["http_proxy"] = "http://a:b@localhost:3128/"
os.environ["https_proxy"] = "http://a:b@localhost2:3128/"
os.environ["no_proxy"] = "127.0.0.0/8, 192.168.0.0/16"
self.assertEqual(get_proxy_info("127.0.0.1", False), (None, 0, None))
self.assertEqual(get_proxy_info("192.168.1.1", False), (None, 0, None))
if __name__ == "__main__":
unittest.main()

View File

@@ -1,497 +0,0 @@
# -*- coding: utf-8 -*-
#
import os
import os.path
import socket
import unittest
from base64 import decodebytes as base64decode
import websocket as ws
from websocket._exceptions import WebSocketBadStatusException, WebSocketAddressException
from websocket._handshake import _create_sec_websocket_key
from websocket._handshake import _validate as _validate_header
from websocket._http import read_headers
from websocket._utils import validate_utf8
"""
test_websocket.py
websocket - WebSocket client library for Python
Copyright 2024 engn33r
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
try:
import ssl
except ImportError:
# dummy class of SSLError for ssl none-support environment.
class SSLError(Exception):
pass
# Skip test to access the internet unless TEST_WITH_INTERNET == 1
TEST_WITH_INTERNET = os.environ.get("TEST_WITH_INTERNET", "0") == "1"
# Skip tests relying on local websockets server unless LOCAL_WS_SERVER_PORT != -1
LOCAL_WS_SERVER_PORT = os.environ.get("LOCAL_WS_SERVER_PORT", "-1")
TEST_WITH_LOCAL_SERVER = LOCAL_WS_SERVER_PORT != "-1"
TRACEABLE = True
def create_mask_key(_):
return "abcd"
class SockMock:
def __init__(self):
self.data = []
self.sent = []
def add_packet(self, data):
self.data.append(data)
def gettimeout(self):
return None
def recv(self, bufsize):
if self.data:
e = self.data.pop(0)
if isinstance(e, Exception):
raise e
if len(e) > bufsize:
self.data.insert(0, e[bufsize:])
return e[:bufsize]
def send(self, data):
self.sent.append(data)
return len(data)
def close(self):
pass
class HeaderSockMock(SockMock):
def __init__(self, fname):
SockMock.__init__(self)
path = os.path.join(os.path.dirname(__file__), fname)
with open(path, "rb") as f:
self.add_packet(f.read())
class WebSocketTest(unittest.TestCase):
def setUp(self):
ws.enableTrace(TRACEABLE)
def tearDown(self):
pass
def test_default_timeout(self):
self.assertEqual(ws.getdefaulttimeout(), None)
ws.setdefaulttimeout(10)
self.assertEqual(ws.getdefaulttimeout(), 10)
ws.setdefaulttimeout(None)
def test_ws_key(self):
key = _create_sec_websocket_key()
self.assertTrue(key != 24)
self.assertTrue("¥n" not in key)
def test_nonce(self):
"""WebSocket key should be a random 16-byte nonce."""
key = _create_sec_websocket_key()
nonce = base64decode(key.encode("utf-8"))
self.assertEqual(16, len(nonce))
def test_ws_utils(self):
key = "c6b8hTg4EeGb2gQMztV1/g=="
required_header = {
"upgrade": "websocket",
"connection": "upgrade",
"sec-websocket-accept": "Kxep+hNu9n51529fGidYu7a3wO0=",
}
self.assertEqual(_validate_header(required_header, key, None), (True, None))
header = required_header.copy()
header["upgrade"] = "http"
self.assertEqual(_validate_header(header, key, None), (False, None))
del header["upgrade"]
self.assertEqual(_validate_header(header, key, None), (False, None))
header = required_header.copy()
header["connection"] = "something"
self.assertEqual(_validate_header(header, key, None), (False, None))
del header["connection"]
self.assertEqual(_validate_header(header, key, None), (False, None))
header = required_header.copy()
header["sec-websocket-accept"] = "something"
self.assertEqual(_validate_header(header, key, None), (False, None))
del header["sec-websocket-accept"]
self.assertEqual(_validate_header(header, key, None), (False, None))
header = required_header.copy()
header["sec-websocket-protocol"] = "sub1"
self.assertEqual(
_validate_header(header, key, ["sub1", "sub2"]), (True, "sub1")
)
# This case will print out a logging error using the error() function, but that is expected
self.assertEqual(_validate_header(header, key, ["sub2", "sub3"]), (False, None))
header = required_header.copy()
header["sec-websocket-protocol"] = "sUb1"
self.assertEqual(
_validate_header(header, key, ["Sub1", "suB2"]), (True, "sub1")
)
header = required_header.copy()
# This case will print out a logging error using the error() function, but that is expected
self.assertEqual(_validate_header(header, key, ["Sub1", "suB2"]), (False, None))
def test_read_header(self):
status, header, _ = read_headers(HeaderSockMock("data/header01.txt"))
self.assertEqual(status, 101)
self.assertEqual(header["connection"], "Upgrade")
status, header, _ = read_headers(HeaderSockMock("data/header03.txt"))
self.assertEqual(status, 101)
self.assertEqual(header["connection"], "Upgrade, Keep-Alive")
HeaderSockMock("data/header02.txt")
self.assertRaises(
ws.WebSocketException, read_headers, HeaderSockMock("data/header02.txt")
)
def test_send(self):
# TODO: add longer frame data
sock = ws.WebSocket()
sock.set_mask_key(create_mask_key)
s = sock.sock = HeaderSockMock("data/header01.txt")
sock.send("Hello")
self.assertEqual(s.sent[0], b"\x81\x85abcd)\x07\x0f\x08\x0e")
sock.send("こんにちは")
self.assertEqual(
s.sent[1],
b"\x81\x8fabcd\x82\xe3\xf0\x87\xe3\xf1\x80\xe5\xca\x81\xe2\xc5\x82\xe3\xcc",
)
# sock.send("x" * 5000)
# self.assertEqual(s.sent[1], b'\x81\x8fabcd\x82\xe3\xf0\x87\xe3\xf1\x80\xe5\xca\x81\xe2\xc5\x82\xe3\xcc")
self.assertEqual(sock.send_binary(b"1111111111101"), 19)
def test_recv(self):
# TODO: add longer frame data
sock = ws.WebSocket()
s = sock.sock = SockMock()
something = (
b"\x81\x8fabcd\x82\xe3\xf0\x87\xe3\xf1\x80\xe5\xca\x81\xe2\xc5\x82\xe3\xcc"
)
s.add_packet(something)
data = sock.recv()
self.assertEqual(data, "こんにちは")
s.add_packet(b"\x81\x85abcd)\x07\x0f\x08\x0e")
data = sock.recv()
self.assertEqual(data, "Hello")
@unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
def test_iter(self):
count = 2
s = ws.create_connection("wss://api.bitfinex.com/ws/2")
s.send('{"event": "subscribe", "channel": "ticker"}')
for _ in s:
count -= 1
if count == 0:
break
@unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
def test_next(self):
sock = ws.create_connection("wss://api.bitfinex.com/ws/2")
self.assertEqual(str, type(next(sock)))
def test_internal_recv_strict(self):
sock = ws.WebSocket()
s = sock.sock = SockMock()
s.add_packet(b"foo")
s.add_packet(socket.timeout())
s.add_packet(b"bar")
# s.add_packet(SSLError("The read operation timed out"))
s.add_packet(b"baz")
with self.assertRaises(ws.WebSocketTimeoutException):
sock.frame_buffer.recv_strict(9)
# with self.assertRaises(SSLError):
# data = sock._recv_strict(9)
data = sock.frame_buffer.recv_strict(9)
self.assertEqual(data, b"foobarbaz")
with self.assertRaises(ws.WebSocketConnectionClosedException):
sock.frame_buffer.recv_strict(1)
def test_recv_timeout(self):
sock = ws.WebSocket()
s = sock.sock = SockMock()
s.add_packet(b"\x81")
s.add_packet(socket.timeout())
s.add_packet(b"\x8dabcd\x29\x07\x0f\x08\x0e")
s.add_packet(socket.timeout())
s.add_packet(b"\x4e\x43\x33\x0e\x10\x0f\x00\x40")
with self.assertRaises(ws.WebSocketTimeoutException):
sock.recv()
with self.assertRaises(ws.WebSocketTimeoutException):
sock.recv()
data = sock.recv()
self.assertEqual(data, "Hello, World!")
with self.assertRaises(ws.WebSocketConnectionClosedException):
sock.recv()
def test_recv_with_simple_fragmentation(self):
sock = ws.WebSocket()
s = sock.sock = SockMock()
# OPCODE=TEXT, FIN=0, MSG="Brevity is "
s.add_packet(b"\x01\x8babcd#\x10\x06\x12\x08\x16\x1aD\x08\x11C")
# OPCODE=CONT, FIN=1, MSG="the soul of wit"
s.add_packet(b"\x80\x8fabcd\x15\n\x06D\x12\r\x16\x08A\r\x05D\x16\x0b\x17")
data = sock.recv()
self.assertEqual(data, "Brevity is the soul of wit")
with self.assertRaises(ws.WebSocketConnectionClosedException):
sock.recv()
def test_recv_with_fire_event_of_fragmentation(self):
sock = ws.WebSocket(fire_cont_frame=True)
s = sock.sock = SockMock()
# OPCODE=TEXT, FIN=0, MSG="Brevity is "
s.add_packet(b"\x01\x8babcd#\x10\x06\x12\x08\x16\x1aD\x08\x11C")
# OPCODE=CONT, FIN=0, MSG="Brevity is "
s.add_packet(b"\x00\x8babcd#\x10\x06\x12\x08\x16\x1aD\x08\x11C")
# OPCODE=CONT, FIN=1, MSG="the soul of wit"
s.add_packet(b"\x80\x8fabcd\x15\n\x06D\x12\r\x16\x08A\r\x05D\x16\x0b\x17")
_, data = sock.recv_data()
self.assertEqual(data, b"Brevity is ")
_, data = sock.recv_data()
self.assertEqual(data, b"Brevity is ")
_, data = sock.recv_data()
self.assertEqual(data, b"the soul of wit")
# OPCODE=CONT, FIN=0, MSG="Brevity is "
s.add_packet(b"\x80\x8babcd#\x10\x06\x12\x08\x16\x1aD\x08\x11C")
with self.assertRaises(ws.WebSocketException):
sock.recv_data()
with self.assertRaises(ws.WebSocketConnectionClosedException):
sock.recv()
def test_close(self):
sock = ws.WebSocket()
sock.connected = True
sock.close
sock = ws.WebSocket()
s = sock.sock = SockMock()
sock.connected = True
s.add_packet(b"\x88\x80\x17\x98p\x84")
sock.recv()
self.assertEqual(sock.connected, False)
def test_recv_cont_fragmentation(self):
sock = ws.WebSocket()
s = sock.sock = SockMock()
# OPCODE=CONT, FIN=1, MSG="the soul of wit"
s.add_packet(b"\x80\x8fabcd\x15\n\x06D\x12\r\x16\x08A\r\x05D\x16\x0b\x17")
self.assertRaises(ws.WebSocketException, sock.recv)
def test_recv_with_prolonged_fragmentation(self):
sock = ws.WebSocket()
s = sock.sock = SockMock()
# OPCODE=TEXT, FIN=0, MSG="Once more unto the breach, "
s.add_packet(
b"\x01\x9babcd.\x0c\x00\x01A\x0f\x0c\x16\x04B\x16\n\x15\rC\x10\t\x07C\x06\x13\x07\x02\x07\tNC"
)
# OPCODE=CONT, FIN=0, MSG="dear friends, "
s.add_packet(b"\x00\x8eabcd\x05\x07\x02\x16A\x04\x11\r\x04\x0c\x07\x17MB")
# OPCODE=CONT, FIN=1, MSG="once more"
s.add_packet(b"\x80\x89abcd\x0e\x0c\x00\x01A\x0f\x0c\x16\x04")
data = sock.recv()
self.assertEqual(data, "Once more unto the breach, dear friends, once more")
with self.assertRaises(ws.WebSocketConnectionClosedException):
sock.recv()
def test_recv_with_fragmentation_and_control_frame(self):
sock = ws.WebSocket()
sock.set_mask_key(create_mask_key)
s = sock.sock = SockMock()
# OPCODE=TEXT, FIN=0, MSG="Too much "
s.add_packet(b"\x01\x89abcd5\r\x0cD\x0c\x17\x00\x0cA")
# OPCODE=PING, FIN=1, MSG="Please PONG this"
s.add_packet(b"\x89\x90abcd1\x0e\x06\x05\x12\x07C4.,$D\x15\n\n\x17")
# OPCODE=CONT, FIN=1, MSG="of a good thing"
s.add_packet(b"\x80\x8fabcd\x0e\x04C\x05A\x05\x0c\x0b\x05B\x17\x0c\x08\x0c\x04")
data = sock.recv()
self.assertEqual(data, "Too much of a good thing")
with self.assertRaises(ws.WebSocketConnectionClosedException):
sock.recv()
self.assertEqual(
s.sent[0], b"\x8a\x90abcd1\x0e\x06\x05\x12\x07C4.,$D\x15\n\n\x17"
)
@unittest.skipUnless(
TEST_WITH_LOCAL_SERVER, "Tests using local websocket server are disabled"
)
def test_websocket(self):
s = ws.create_connection(f"ws://127.0.0.1:{LOCAL_WS_SERVER_PORT}")
self.assertNotEqual(s, None)
s.send("Hello, World")
result = s.next()
s.fileno()
self.assertEqual(result, "Hello, World")
s.send("こにゃにゃちは、世界")
result = s.recv()
self.assertEqual(result, "こにゃにゃちは、世界")
self.assertRaises(ValueError, s.send_close, -1, "")
s.close()
@unittest.skipUnless(
TEST_WITH_LOCAL_SERVER, "Tests using local websocket server are disabled"
)
def test_ping_pong(self):
s = ws.create_connection(f"ws://127.0.0.1:{LOCAL_WS_SERVER_PORT}")
self.assertNotEqual(s, None)
s.ping("Hello")
s.pong("Hi")
s.close()
@unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
def test_support_redirect(self):
s = ws.WebSocket()
self.assertRaises(WebSocketBadStatusException, s.connect, "ws://google.com/")
# Need to find a URL that has a redirect code leading to a websocket
@unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
def test_secure_websocket(self):
s = ws.create_connection("wss://api.bitfinex.com/ws/2")
self.assertNotEqual(s, None)
self.assertTrue(isinstance(s.sock, ssl.SSLSocket))
self.assertEqual(s.getstatus(), 101)
self.assertNotEqual(s.getheaders(), None)
s.settimeout(10)
self.assertEqual(s.gettimeout(), 10)
self.assertEqual(s.getsubprotocol(), None)
s.abort()
@unittest.skipUnless(
TEST_WITH_LOCAL_SERVER, "Tests using local websocket server are disabled"
)
def test_websocket_with_custom_header(self):
s = ws.create_connection(
f"ws://127.0.0.1:{LOCAL_WS_SERVER_PORT}",
headers={"User-Agent": "PythonWebsocketClient"},
)
self.assertNotEqual(s, None)
self.assertEqual(s.getsubprotocol(), None)
s.send("Hello, World")
result = s.recv()
self.assertEqual(result, "Hello, World")
self.assertRaises(ValueError, s.close, -1, "")
s.close()
@unittest.skipUnless(
TEST_WITH_LOCAL_SERVER, "Tests using local websocket server are disabled"
)
def test_after_close(self):
s = ws.create_connection(f"ws://127.0.0.1:{LOCAL_WS_SERVER_PORT}")
self.assertNotEqual(s, None)
s.close()
self.assertRaises(ws.WebSocketConnectionClosedException, s.send, "Hello")
self.assertRaises(ws.WebSocketConnectionClosedException, s.recv)
class SockOptTest(unittest.TestCase):
@unittest.skipUnless(
TEST_WITH_LOCAL_SERVER, "Tests using local websocket server are disabled"
)
def test_sockopt(self):
sockopt = ((socket.IPPROTO_TCP, socket.TCP_NODELAY, 1),)
s = ws.create_connection(
f"ws://127.0.0.1:{LOCAL_WS_SERVER_PORT}", sockopt=sockopt
)
self.assertNotEqual(
s.sock.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY), 0
)
s.close()
class UtilsTest(unittest.TestCase):
def test_utf8_validator(self):
state = validate_utf8(b"\xf0\x90\x80\x80")
self.assertEqual(state, True)
state = validate_utf8(
b"\xce\xba\xe1\xbd\xb9\xcf\x83\xce\xbc\xce\xb5\xed\xa0\x80edited"
)
self.assertEqual(state, False)
state = validate_utf8(b"")
self.assertEqual(state, True)
class HandshakeTest(unittest.TestCase):
@unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
def test_http_ssl(self):
websock1 = ws.WebSocket(
sslopt={"cert_chain": ssl.get_default_verify_paths().capath},
enable_multithread=False,
)
self.assertRaises(ValueError, websock1.connect, "wss://api.bitfinex.com/ws/2")
websock2 = ws.WebSocket(sslopt={"certfile": "myNonexistentCertFile"})
self.assertRaises(
FileNotFoundError, websock2.connect, "wss://api.bitfinex.com/ws/2"
)
@unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
def test_manual_headers(self):
websock3 = ws.WebSocket(
sslopt={
"ca_certs": ssl.get_default_verify_paths().cafile,
"ca_cert_path": ssl.get_default_verify_paths().capath,
}
)
self.assertRaises(
WebSocketBadStatusException,
websock3.connect,
"wss://api.bitfinex.com/ws/2",
cookie="chocolate",
origin="testing_websockets.com",
host="echo.websocket.events/websocket-client-test",
subprotocols=["testproto"],
connection="Upgrade",
header={
"CustomHeader1": "123",
"Cookie": "TestValue",
"Sec-WebSocket-Key": "k9kFAUWNAMmf5OEMfTlOEA==",
"Sec-WebSocket-Protocol": "newprotocol",
},
)
def test_ipv6(self):
websock2 = ws.WebSocket()
self.assertRaises(ValueError, websock2.connect, "2001:4860:4860::8888")
def test_bad_urls(self):
websock3 = ws.WebSocket()
self.assertRaises(ValueError, websock3.connect, "ws//example.com")
self.assertRaises(WebSocketAddressException, websock3.connect, "ws://example")
self.assertRaises(ValueError, websock3.connect, "example.com")
if __name__ == "__main__":
unittest.main()

View File

@@ -1,199 +1,107 @@
# Python imports
from os import path
import json
# Lib imports
import gi
gi.require_version('Gtk', '3.0')
gi.require_version('GtkSource', '4')
from gi.repository import Gtk
from gi.repository import GLib
from gi.repository import GtkSource
# Application imports
from libs.dto.code.lsp.lsp_message_structs import LSPResponseTypes, LSPResponseRequest, LSPResponseNotification
from .provider import Provider
from .provider_response_cache import ProviderResponseCache
from .lsp_manager_ui import LSPManagerUI
from .lsp_manager_client import LSPManagerClient
from .response_handlers.response_registry import ResponseRegistry
class LSPManager(Gtk.Dialog):
class LSPManager:
def __init__(self):
super(LSPManager, self).__init__()
self._SCRIPT_PTH: str = path.dirname( path.realpath(__file__) )
self._USER_HOME: str = path.expanduser('~')
self._LSP_SERVERS_CONFIG: str = ""
self.servers_config: dict = {}
self.provider: Provider = Provider()
self.parent = None
self.source_view = None
self._setup_styling()
self._setup_signals()
self._subscribe_to_events()
self._init()
self._load_widgets()
self._do_bind_mapping()
def _setup_styling(self):
self.set_modal(True)
self.set_decorated(False)
self.set_vexpand(True)
self.set_hexpand(True)
def _setup_signals(self):
self.connect("show", self._show)
def _subscribe_to_events(self):
...
def _init(self):
self.provider: Provider = Provider()
self.response_cache: ProviderResponseCache = ProviderResponseCache()
self.lsp_manager_client: LSPManagerClient = LSPManagerClient()
self.response_registry: ResponseRegistry = ResponseRegistry()
def _load_widgets(self):
content_area = self.get_content_area()
self.main_box = Gtk.Grid()
self.path_entry = Gtk.SearchEntry()
self.path_bttn = Gtk.FileChooserButton.new(
title = "Workspace Folder",
action = Gtk.FileChooserAction.SELECT_FOLDER
self.lsp_manager_ui: LSPManagerUI = LSPManagerUI()
self.lsp_manager_ui.connect('create-client', self._on_create_client)
self.lsp_manager_ui.connect('close-client', self._on_close_client)
def _do_bind_mapping(self):
self.response_cache.set_lsp_client(self.lsp_manager_client)
self.provider.response_cache = self.response_cache
def _on_create_client(self, ui, lang_id: str, workspace_uri: str) -> bool:
init_opts = ui.get_init_opts(lang_id)
result = self.create_client(lang_id, workspace_uri, init_opts)
if result:
ui.toggle_client_buttons(show_close=True)
return result
def _on_close_client(self, ui, lang_id: str) -> bool:
result = self.close_client(lang_id)
if result:
ui.toggle_client_buttons(show_close=False)
return result
def create_client(
self,
lang_id: str = "python",
workspace_uri: str = "",
init_opts: dict = {}
) -> bool:
client = self.lsp_manager_client.create_client(
lang_id, workspace_uri, init_opts
)
self.combo_box = Gtk.ComboBoxText()
handler = self.response_registry.get_handler(lang_id)
self.lsp_manager_client.active_language_id = lang_id
self.hide_bttn = Gtk.Button(label = "X")
bttn_box = Gtk.Box()
create_client_bttn = Gtk.Button(label = "Create Language Client")
close_client_bttn = Gtk.Button(label = "Close Language Client")
if not client or not handler:
logger.error(f"LSP Manager: Either 'client' or 'handler' didn't get created...'")
self.close_client(lang_id)
return False
self.path_entry.set_can_focus(False)
self.path_entry.set_placeholder_text("Workspace Folder...")
self.path_entry.connect("changed", self._path_changed, bttn_box)
handler.set_context(self.response_registry)
handler.set_response_cache(self.response_cache)
self.path_bttn.connect("file-set", self._file_set)
self.path_bttn.set_halign(Gtk.Align.FILL)
self.hide_bttn.connect("clicked", lambda widget: self.hide())
create_client_bttn.connect("clicked", self.create_client, close_client_bttn)
close_client_bttn.connect("clicked", self.close_client, create_client_bttn)
client.handle_lsp_response = self.server_response
client.send_initialize_message(init_opts, "", f"file://{workspace_uri}")
self.main_box.set_column_spacing(15)
self.main_box.set_row_spacing(15)
return True
bttn_box.pack_start(create_client_bttn, False, False, 0)
bttn_box.pack_start(close_client_bttn, False, False, 0)
def close_client(self, lang_id: str) -> bool:
self.lsp_manager_client.close_client(lang_id)
self.response_registry.close_handler(lang_id)
self.main_box.attach(child = self.path_entry, left = 0, top = 0, width = 4, height = 1)
self.main_box.attach(child = self.path_bttn, left = 4, top = 0, width = 1, height = 1)
self.main_box.attach(child = self.combo_box, left = 5, top = 0, width = 1, height = 1)
self.main_box.attach(child = self.hide_bttn, left = 6, top = 0, width = 1, height = 1)
self.main_box.attach(child = bttn_box, left = 0, top = 1, width = 1, height = 1)
return True
content_area.set_vexpand(True)
content_area.set_hexpand(True)
def server_response(self, lsp_response: LSPResponseTypes):
logger.debug(f"LSP Response: { lsp_response }")
content_area.add(self.main_box)
content_area.show_all()
close_client_bttn.hide()
bttn_box.hide()
if isinstance(lsp_response, LSPResponseRequest):
if not self.lsp_manager_client.active_language_id in self.lsp_manager_client.clients:
logger.debug(f"No LSP client for '{self.lsp_manager_client.active_language_id}', skipping 'server_response'")
return
controller = self.lsp_manager_client.get_active_client()
event = controller.get_event_by_id(lsp_response.id)
handler = self.response_registry.get_handler(
self.lsp_manager_client.active_language_id, event
)
def _show(self, widget):
GLib.idle_add(self.path_entry.grab_focus)
if not handler: return
handler.handle(event, lsp_response.result, controller)
elif isinstance(lsp_response, LSPResponseNotification):
handler = self.response_registry.get_handler("default", lsp_response.method)
def _map_resize(self, widget, parent):
parent_x, parent_y = parent.get_position()
parent_width, parent_height = parent.get_size()
if parent_width == 0 or parent_height == 0: return
if not handler: return
width = int(parent_width * 0.75)
height = int(parent_height * 0.75)
widget.resize(width, height)
x = parent_x + (parent_width - width) // 2
y = parent_y + (parent_height - height) // 2
widget.move(x, y)
def _path_changed(self, widget, buttons_widget):
fpath = widget.get_text()
if not fpath:
buttons_widget.hide()
return
buttons_widget.show()
def _file_set(self, widget):
self.path_entry.set_text(
widget.get_filename()
)
self.load_lsp_servers_config_placeholders()
def map_parent_resize_event(self, parent):
parent.connect("size-allocate", lambda w, r: self._map_resize(self, parent))
def set_source_view(self, source_view):
scrolled_win = Gtk.ScrolledWindow()
lang_manager = GtkSource.LanguageManager()
buffer = source_view.get_buffer()
language = lang_manager.get_language("json")
self.source_view = source_view
buffer.set_language(language)
buffer.set_style_scheme(self.source_view.syntax_theme)
scrolled_win.set_hexpand(True)
scrolled_win.set_vexpand(True)
scrolled_win.add(self.source_view)
self.main_box.attach(child = scrolled_win, left = 0, top = 2, width = 7, height = 1)
scrolled_win.show_all()
def load_lsp_servers_config(self):
with open(f"{self._SCRIPT_PTH}/configs/lsp-servers-config.json") as file:
self._LSP_SERVERS_CONFIG = file.read()
def load_lsp_servers_config_placeholders(self):
data = self._LSP_SERVERS_CONFIG \
.replace("{user.home}", self._USER_HOME) \
.replace("{workspace.folder}", self.path_entry.get_text())
self.servers_config = json.loads(data)
buffer = self.source_view.get_buffer()
start_itr, \
end_itr = buffer.get_bounds()
buffer.delete(start_itr, end_itr)
buffer.insert(start_itr, data, -1)
self.set_language_combo_box( self.servers_config.keys() )
def set_language_combo_box(self, lang_ids: list[str]):
for lang_id in lang_ids:
self.combo_box.append_text(lang_id)
def create_client(self, widget, sibling):
buffer = self.source_view.get_buffer()
lang_id = self.combo_box.get_active_text()
if not lang_id: return
if not lang_id in self.servers_config: return
self.servers_config = json.loads( buffer.get_text( *buffer.get_bounds() ) )
init_opts = self.servers_config[lang_id]["initialization-options"]
workspace_dir = self.path_entry.get_text()
result = self.provider.response_cache.create_client(
lang_id, workspace_dir, init_opts
)
if not result: return
widget.hide()
sibling.show()
def close_client(self, widget, sibling):
lang_id = self.combo_box.get_active_text()
if not lang_id: return
result = self.provider.response_cache.close_client(lang_id)
if not result: return
widget.hide()
sibling.show()
handler.set_context(self.response_registry)
handler.set_response_cache(self.response_cache)
handler.handle(lsp_response.method, lsp_response.params, None)

View File

@@ -0,0 +1,57 @@
# Python imports
from concurrent.futures import ThreadPoolExecutor
# Lib imports
# Application imports
from .mixins.lsp_client_events_mixin import LSPClientEventsMixin
from .client.lsp_client import LSPClient
class LSPManagerClient(LSPClientEventsMixin):
def __init__(self):
super(LSPManagerClient, self).__init__()
self._cache_refresh_timeout_id: int = None
self.executor: ThreadPoolExecutor = ThreadPoolExecutor(max_workers = 1)
self.active_language_id: str = ""
self.clients: dict = {}
def create_client(
self,
lang_id: str = "python",
workspace_uri: str = "",
init_opts: dict = {}
) -> LSPClient:
if lang_id in self.clients: return None
address = "127.0.0.1"
port = 9999
uri = f"ws://{address}:{port}/{lang_id}"
client = LSPClient()
client.set_language(lang_id)
client.set_socket(uri)
client.start_client()
if not client.ws_client.wait_for_connection(timeout = 5.0):
logger.error(f"Failed to connect to LSP server for {lang_id}")
return None
self.clients[lang_id] = client
return client
def close_client(self, lang_id: str) -> bool:
if lang_id not in self.clients: return False
controller = self.clients.pop(lang_id)
controller.stop_client()
return True
def get_active_client(self) -> LSPClient:
return self.clients[self.active_language_id]

View File

@@ -0,0 +1,215 @@
# Python imports
from os import path
import json
# Lib imports
import gi
gi.require_version('Gtk', '3.0')
gi.require_version('GtkSource', '4')
from gi.repository import GObject
from gi.repository import Gtk
from gi.repository import GLib
from gi.repository import GtkSource
# Application imports
class LSPManagerUI(Gtk.Dialog):
__gsignals__ = {
'create-client': (GObject.SignalFlags.RUN_LAST, None, (str, str)),
'close-client': (GObject.SignalFlags.RUN_LAST, None, (str,)),
}
def __init__(self):
super(LSPManagerUI, self).__init__()
self._SCRIPT_PTH: str = path.dirname( path.realpath(__file__) )
self._USER_HOME: str = path.expanduser('~')
self._LSP_SERVERS_CONFIG: str = ""
self.servers_config: dict = {}
self.parent = None
self.source_view = None
self._setup_styling()
self._setup_signals()
self._subscribe_to_events()
self._load_widgets()
def _setup_styling(self):
self.set_modal(True)
self.set_decorated(False)
self.set_vexpand(True)
self.set_hexpand(True)
def _setup_signals(self):
self.connect("show", self._show)
def _subscribe_to_events(self):
...
def _load_widgets(self):
content_area = self.get_content_area()
self.main_box = Gtk.Grid()
self.path_entry = Gtk.SearchEntry()
self.path_bttn = Gtk.FileChooserButton.new(
title = "Workspace Folder",
action = Gtk.FileChooserAction.SELECT_FOLDER
)
self.combo_box = Gtk.ComboBoxText()
self.hide_bttn = Gtk.Button(label = "X")
bttn_box = Gtk.Box()
self.create_client_bttn = Gtk.Button(label = "Create Language Client")
self.close_client_bttn = Gtk.Button(label = "Close Language Client")
self.path_entry.set_can_focus(False)
self.path_entry.set_placeholder_text("Workspace Folder...")
self.path_entry.connect("changed", self._path_changed, bttn_box)
self.path_bttn.connect("file-set", self._file_set)
self.path_bttn.set_halign(Gtk.Align.FILL)
self.hide_bttn.connect("clicked", lambda widget: self.hide())
self.create_client_bttn.connect("clicked", self._create_client, self.close_client_bttn)
self.close_client_bttn.connect("clicked", self._close_client, self.create_client_bttn)
self.main_box.set_column_spacing(15)
self.main_box.set_row_spacing(15)
bttn_box.pack_start(self.create_client_bttn, False, False, 0)
bttn_box.pack_start(self.close_client_bttn, False, False, 0)
self.main_box.attach(child = self.path_entry, left = 0, top = 0, width = 4, height = 1)
self.main_box.attach(child = self.path_bttn, left = 4, top = 0, width = 1, height = 1)
self.main_box.attach(child = self.combo_box, left = 5, top = 0, width = 1, height = 1)
self.main_box.attach(child = self.hide_bttn, left = 6, top = 0, width = 1, height = 1)
self.main_box.attach(child = bttn_box, left = 0, top = 1, width = 1, height = 1)
content_area.set_vexpand(True)
content_area.set_hexpand(True)
content_area.add(self.main_box)
content_area.show_all()
self.close_client_bttn.hide()
bttn_box.hide()
def _show(self, widget):
GLib.idle_add(self.path_entry.grab_focus)
def _map_resize(self, widget, parent):
parent_x, parent_y = parent.get_position()
parent_width, parent_height = parent.get_size()
if parent_width == 0 or parent_height == 0: return
width = int(parent_width * 0.75)
height = int(parent_height * 0.75)
widget.resize(width, height)
x = parent_x + (parent_width - width) // 2
y = parent_y + (parent_height - height) // 2
widget.move(x, y)
def _path_changed(self, widget, buttons_widget):
if not widget.get_text():
buttons_widget.hide()
return
buttons_widget.show()
def _file_set(self, widget):
self.path_entry.set_text(
widget.get_filename()
)
self.load_lsp_servers_config_placeholders()
def map_parent_resize_event(self, parent):
parent.connect("size-allocate", lambda w, r: self._map_resize(self, parent))
def set_source_view(self, source_view):
scrolled_win = Gtk.ScrolledWindow()
lang_manager = GtkSource.LanguageManager()
buffer = source_view.get_buffer()
language = lang_manager.get_language("json")
self.source_view = source_view
buffer.set_language(language)
buffer.set_style_scheme(self.source_view.syntax_theme)
scrolled_win.set_hexpand(True)
scrolled_win.set_vexpand(True)
scrolled_win.add(self.source_view)
self.main_box.attach(child = scrolled_win, left = 0, top = 2, width = 7, height = 1)
scrolled_win.show_all()
def load_lsp_servers_config(self):
try:
with open(f"{self._SCRIPT_PTH}/configs/lsp-servers-config.json") as file:
self._LSP_SERVERS_CONFIG = file.read()
except FileNotFoundError:
logger.error(f"Config file not found: {self._SCRIPT_PTH}/configs/lsp-servers-config.json")
def load_lsp_servers_config_placeholders(self):
if not self._LSP_SERVERS_CONFIG: return
if not self.source_view: return
data = self._LSP_SERVERS_CONFIG \
.replace("{user.home}", self._USER_HOME) \
.replace("{workspace.folder}", self.path_entry.get_text())
self.servers_config = json.loads(data)
buffer = self.source_view.get_buffer()
start_itr, \
end_itr = buffer.get_bounds()
buffer.delete(start_itr, end_itr)
buffer.insert(start_itr, data, -1)
self.set_language_combo_box( list(self.servers_config.keys()) )
def set_language_combo_box(self, lang_ids: list[str]):
self.combo_box.remove_all()
for lang_id in lang_ids:
self.combo_box.append_text(lang_id)
def get_init_opts(self, lang_id: str) -> dict:
buffer = self.source_view.get_buffer()
try:
self.servers_config = json.loads(
buffer.get_text( *buffer.get_bounds() )
)
except json.JSONDecodeError as e:
logger.error(f"Invalid JSON: {e}")
return {}
if not lang_id or not lang_id in self.servers_config: return {}
return self.servers_config[lang_id].get("initialization-options", {})
def _create_client(self, widget, sibling):
if not self.source_view: return
buffer = self.source_view.get_buffer()
lang_id = self.combo_box.get_active_text()
if not lang_id: return
workspace_dir = self.path_entry.get_text()
self.emit('create-client', lang_id, workspace_dir)
def _close_client(self, widget, sibling):
lang_id = self.combo_box.get_active_text()
if not lang_id: return
self.emit('close-client', lang_id)
def toggle_client_buttons(self, show_close: bool):
self.create_client_bttn.set_visible(not show_close)
self.close_client_bttn.set_visible(show_close)

View File

@@ -23,7 +23,7 @@ class LSPClientEventsMixin:
uri = f"file://{fpath}" if not fpath.startswith("file://") else fpath
buffer = event.file.buffer
text = buffer.get_text(*buffer.get_bounds())
self._last_active_language_id = lang_id
self.active_language_id = lang_id
controller._lsp_did_open({
"uri": uri,
@@ -54,7 +54,7 @@ class LSPClientEventsMixin:
uri = f"file://{fpath}" if not fpath.startswith("file://") else fpath
buffer = event.file.buffer
text = buffer.get_text(*buffer.get_bounds())
self._last_active_language_id = lang_id
self.active_language_id = lang_id
controller._lsp_did_save({"uri": uri, "text": text})
@@ -71,7 +71,7 @@ class LSPClientEventsMixin:
uri = f"file://{fpath}" if not fpath.startswith("file://") else fpath
buffer = event.file.buffer
text = buffer.get_text(*buffer.get_bounds())
self._last_active_language_id = lang_id
self.active_language_id = lang_id
controller._lsp_did_change({
"uri": uri,
@@ -97,7 +97,7 @@ class LSPClientEventsMixin:
controller = self.clients[lang_id]
uri = f"file://{fpath}" if not fpath.startswith("file://") else fpath
self._last_active_language_id = lang_id
self.active_language_id = lang_id
controller._lsp_definition({
"uri": uri,
@@ -116,7 +116,7 @@ class LSPClientEventsMixin:
controller = self.clients[lang_id]
uri = f"file://{fpath}" if not fpath.startswith("file://") else fpath
self._last_active_language_id = lang_id
self.active_language_id = lang_id
controller._lsp_completion({
"uri": uri,

View File

@@ -1,91 +0,0 @@
# Python imports
# Lib imports
import gi
gi.require_version('GtkSource', '4')
from gi.repository import GLib
from gi.repository import GtkSource
# Application imports
from libs.event_factory import Event_Factory, Code_Event_Types
class LSPServerEventsMixin:
def _handle_definition_response(self, uri: str, pointer_pos: dict):
self._prompt_goto_request(uri, pointer_pos)
def _handle_completion_response(self, result: dict or list):
if not result: return
items = []
if isinstance(result, dict):
items = result.get("items", [])
elif isinstance(result, list):
items = result
self.matchers.clear()
for item in items:
label = item.get("label", "")
if not label: continue
text = item.get("insertText")
if not text and "textEdit" in item:
text = item["textEdit"].get("newText", "")
info = ""
if "detail" in item:
info = item["detail"]
elif "documentation" in item:
doc = item["documentation"]
if isinstance(doc, dict):
info = doc.get("value", "")
else:
info = str(doc)
self.matchers[label] = {
"label": label,
"text": text,
"info": info
}
self._prompt_completion_request()
def _prompt_goto_request(self, uri: str, pointer_pos: dict):
event = Event_Factory.create_event(
"get_active_view",
)
self.emit_to("source_views", event)
view = event.response
view._on_uri_data_received( [uri] )
buffer = view.get_buffer()
def move_cursor(buffer, pointer_pos):
itr = buffer.get_iter_at_line( pointer_pos["end"]["line"] )
itr.forward_chars( pointer_pos["end"]["character"] )
buffer.place_cursor(itr)
view.scroll_to_iter(itr, 0.2, False, 0, 0)
GLib.idle_add( move_cursor, buffer, pointer_pos )
def _handle_java_class_file_contents(self, text: str):
event = Event_Factory.create_event(
"get_active_view",
)
self.emit_to("source_views", event)
view = event.response
file = view.command.exec("new_file")
buffer = view.get_buffer()
itr = buffer.get_iter_at_mark( buffer.get_insert() )
lm = GtkSource.LanguageManager.get_default()
language = lm.get_language("java")
file.ftype = "java"
buffer.set_language(language)
buffer.insert(itr, text, -1)

View File

@@ -1,6 +1,9 @@
# Python imports
# Lib imports
import gi
from gi.repository import GLib
# Application imports
from libs.event_factory import Event_Factory, Code_Event_Types
@@ -27,7 +30,7 @@ class Plugin(PluginCode):
def load(self):
window = self.request_ui_element("main-window")
lsp_manager.map_parent_resize_event(window)
lsp_manager.lsp_manager_ui.map_parent_resize_event(window)
event = Event_Factory.create_event("register_command",
command_name = "LSP Manager",
@@ -37,7 +40,7 @@ class Plugin(PluginCode):
)
self.emit_to("source_views", event)
event = Event_Factory.create_event(
event = Event_Factory.create_event(
"register_provider",
provider_name = "LSP Completer",
provider = lsp_manager.provider,
@@ -52,12 +55,11 @@ class Plugin(PluginCode):
self.emit_to("source_views", event)
source_view = event.response
lsp_manager.load_lsp_servers_config()
lsp_manager.set_source_view(source_view)
lsp_manager.load_lsp_servers_config_placeholders()
lsp_manager.provider.response_cache.emit = self.emit
lsp_manager.provider.response_cache.emit_to = self.emit_to
lsp_manager.provider.response_cache._prompt_completion_request = self._prompt_completion_request
lsp_manager.lsp_manager_ui.load_lsp_servers_config()
lsp_manager.lsp_manager_ui.set_source_view(source_view)
lsp_manager.lsp_manager_ui.load_lsp_servers_config_placeholders()
lsp_manager.response_registry.set_event_hub(self.emit, self.emit_to, lsp_manager.provider)
def run(self):
...
@@ -65,20 +67,6 @@ class Plugin(PluginCode):
def generate_plugin_element(self):
...
def _prompt_completion_request(self):
event = Event_Factory.create_event(
"get_active_view",
)
self.emit_to("source_views", event)
view = event.response
event = Event_Factory.create_event(
"request_completion",
view = view,
provider = lsp_manager.provider
)
self.emit_to("completion", event)
class Handler:
@staticmethod
@@ -98,7 +86,7 @@ class Handler:
column = iter.get_line_offset()
if char_str == "g":
lsp_manager.provider.response_cache.process_goto_definition(
lsp_manager.lsp_manager_client.process_goto_definition(
file.ftype, file.fpath, line, column
)
@@ -107,4 +95,4 @@ class Handler:
if char_str == "i":
return
lsp_manager.hide() if lsp_manager.is_visible() else lsp_manager.show()
lsp_manager.lsp_manager_ui.hide() if lsp_manager.lsp_manager_ui.is_visible() else lsp_manager.lsp_manager_ui.show()

View File

@@ -0,0 +1,2 @@
from .provider import Provider
from .provider_response_cache import ProviderResponseCache

View File

@@ -22,7 +22,7 @@ class Provider(GObject.GObject, GtkSource.CompletionProvider):
def __init__(self):
super(Provider, self).__init__()
self.response_cache: ProviderResponseCache = ProviderResponseCache()
self.response_cache: ProviderResponseCache = None
def pre_populate(self, context):
@@ -32,13 +32,19 @@ class Provider(GObject.GObject, GtkSource.CompletionProvider):
return "LSP Code Completion"
def do_match(self, context):
# Note: If provider is in interactive activation then need to check
# view focus as otherwise non focus views start trying to grab it.
# completion = context.get_property("completion")
# if not completion.get_view().has_focus(): return
iter = self.response_cache.get_iter_correctly(context)
iter.backward_char()
ch = iter.get_char()
# NOTE: Look to re-add or apply supprting logic to use spaces
# NOTE: Look to re-add or apply supporting logic to use spaces
# As is it slows down the editor in certain contexts...
if not (ch in ('_', '.', ' ') or ch.isalnum()):
# if not (ch in ('_', '.', ' ') or ch.isalnum()):
if not (ch in ('_', '.') or ch.isalnum()):
return False
buffer = iter.get_buffer()

View File

@@ -0,0 +1,44 @@
# Python imports
# Lib imports
import gi
gi.require_version('GtkSource', '4')
from gi.repository import GtkSource
# Application imports
from core.widgets.code.completion_providers.provider_response_cache_base import ProviderResponseCacheBase
class ProviderResponseCache(ProviderResponseCacheBase):
def __init__(self):
super(ProviderResponseCache, self).__init__()
self.matchers: dict = {}
self._lsp_client = None
def set_lsp_client(self, lsp_client):
self._lsp_client = lsp_client
def process_file_load(self, event):
if self._lsp_client:
self._lsp_client.process_file_load(event)
def process_file_close(self, event):
if self._lsp_client:
self._lsp_client.process_file_close(event)
def process_file_save(self, event):
if self._lsp_client:
self._lsp_client.process_file_save(event)
def process_file_change(self, event):
if self._lsp_client:
self._lsp_client.process_file_change(event)
def filter(self, word: str) -> list[dict]:
return []
def filter_with_context(self, context: GtkSource.CompletionContext) -> list[dict]:
return list( self.matchers.values() )

View File

@@ -1,126 +0,0 @@
# Python imports
from concurrent.futures import ThreadPoolExecutor
import asyncio
from asyncio import Queue
# Lib imports
import gi
gi.require_version('GtkSource', '4')
from gi.repository import GtkSource
# Application imports
from libs.dto.code.lsp.lsp_message_structs import LSPResponseTypes, LSPResponseRequest, LSPResponseNotification
from core.widgets.code.completion_providers.provider_response_cache_base import ProviderResponseCacheBase
from .controllers.lsp_controller import LSPController
from .mixins.lsp_client_events_mixin import LSPClientEventsMixin
from .mixins.lsp_server_events_mixin import LSPServerEventsMixin
class ProviderResponseCache(LSPClientEventsMixin, LSPServerEventsMixin, ProviderResponseCacheBase):
def __init__(self):
super(ProviderResponseCache, self).__init__()
self.executor = ThreadPoolExecutor(max_workers = 1)
self.matchers: dict = {}
self.clients: dict = {}
self._cache_refresh_timeout_id: int = None
self._last_active_language_id: str = None
def create_client(
self,
lang_id: str = "python",
workspace_uri: str = "",
init_opts: dict = {
}) -> bool:
if lang_id in self.clients: return False
address = "127.0.0.1"
port = 9999
uri = f"ws://{address}:{port}/{lang_id}"
controller = LSPController()
controller.handle_lsp_response = self.server_response
controller.set_language(lang_id)
controller.set_socket(uri)
controller.start_client()
if not controller.ws_client.wait_for_connection(timeout = 5.0):
logger.error(f"Failed to connect to LSP server for {lang_id}")
return False
self.clients[lang_id] = controller
controller.send_initialize_message(init_opts, "", f"file://{workspace_uri}")
return True
def close_client(self, lang_id: str) -> bool:
if lang_id not in self.clients: return False
controller = self.clients.pop(lang_id)
controller.stop_client()
return True
# TODO: Need to map 'lang_id' to a given language response class and
# pass the controller to a 'server_response' method there.
# It would allow clean separation of each language's idiosyncracies
def server_response(self, lsp_response: LSPResponseTypes):
logger.debug(f"LSP Response: { lsp_response }")
if isinstance(lsp_response, LSPResponseRequest):
if not self._last_active_language_id in self.clients:
logger.debug(f"No LSP client for '{self._last_active_language_id}', skipping 'server_response'")
return
controller = self.clients[self._last_active_language_id]
event = controller.get_event_by_id(lsp_response.id)
match event:
case "textDocument/completion":
self._handle_completion_response(lsp_response.result)
case "textDocument/definition":
result = lsp_response.result
if not result: return
uri = result[0]["uri"]
if "jdt://" in uri:
controller._lsp_java_class_file_contents(uri)
return
self._handle_definition_response(uri, result[0]["range"])
case "java/classFileContents":
self._handle_java_class_file_contents(lsp_response.result)
case _:
...
elif isinstance(lsp_response, LSPResponseNotification):
match lsp_response.method:
case "textDocument/publishDiagnostics":
...
case _:
...
def filter(self, word: str) -> list[dict]:
return []
def filter_with_context(self, context: GtkSource.CompletionContext) -> list[dict]:
response = []
iter = self.get_iter_correctly(context)
iter.backward_char()
char_str = iter.get_char()
if char_str == "." or char_str == " ":
for label, item in self.matchers.items():
response.append(item)
return response
word = self.get_word(context).rstrip()
for label, item in self.matchers.items():
if label.startswith(word):
response.append(item)
return response

View File

@@ -0,0 +1,5 @@
from .base import BaseHandler
from .default import DefaultHandler
from .python import PythonHandler
from .java import JavaHandler
from .response_registry import ResponseRegistry

View File

@@ -0,0 +1,30 @@
# Python imports
# Lib imports
# Application imports
class BaseHandler:
def __init__(self):
self.context = None
self.response_cache = None
def set_context(self, context):
self.context = context
def set_response_cache(self, response_cache):
self.response_cache = response_cache
@property
def emit(self):
return self.context.emit
@property
def emit_to(self):
return self.context.emit_to
def handle(self, method: str, response, controller):
pass

View File

@@ -0,0 +1,134 @@
# Python imports
# Lib imports
import gi
from gi.repository import GLib
# Application imports
from libs.event_factory import Event_Factory, Code_Event_Types
from .base import BaseHandler
class DefaultHandler(BaseHandler):
"""Fallback handler for unknown languages - uses generic LSP handling."""
def handle(self, method: str, response, controller):
match method:
case "textDocument/completion":
self._handle_completion(response)
case "textDocument/definition":
self._handle_definition(response, controller)
case "textDocument/publishDiagnostics":
self._handle_diagnostics(response)
def _handle_completion(self, result):
if not result: return
items = result.get("items", []) if isinstance(result, dict) else result
self.response_cache.matchers.clear()
for item in items:
label = item.get("label")
if not label:
continue
text = (
item.get("insertText")
or item.get("textEdit", {}).get("newText")
or item.get("textEditText", "")
or label
)
detail = item.get("detail")
doc = item.get("documentation")
if detail:
info = detail
elif isinstance(doc, dict):
info = doc.get("value", "")
else:
info = str(doc) if doc else ""
self.response_cache.matchers[label] = {
"label": label,
"text": text,
"info": info,
}
self._prompt_completion_request()
def _handle_definition(self, response, controller):
if not response: return
uri = response[0]["uri"]
self._prompt_goto_request(uri, response[0]["range"])
def _handle_diagnostics(self, params):
if not params: return
uri = params.get("uri", "")
diagnostics = params.get("diagnostics", [])
errors = []
warnings = []
hints = []
for diag in diagnostics:
severity = diag.get("severity", 1)
message = diag.get("message", "")
range = diag.get("range", {})
diag_info = {
"message": message,
"range": range
}
if severity == 1:
errors.append(diag_info)
elif severity == 2:
warnings.append(diag_info)
elif severity == 3:
hints.append(diag_info)
self.response_cache.lsp_diagnostics = {
"uri": uri,
"errors": errors,
"warnings": warnings,
"hints": hints
}
logger.debug(f"LSP Diagnostics for {uri}: {len(errors)} errors, {len(warnings)} warnings, {len(hints)} hints")
def _prompt_goto_request(self, uri: str, pointer_pos: dict):
event = Event_Factory.create_event(
"get_active_view",
)
self.emit_to("source_views", event)
view = event.response
view._on_uri_data_received( [uri] )
buffer = view.get_buffer()
def move_cursor(buffer, pointer_pos):
itr = buffer.get_iter_at_line( pointer_pos["end"]["line"] )
itr.forward_chars( pointer_pos["end"]["character"] )
buffer.place_cursor(itr)
view.scroll_to_iter(itr, 0.2, False, 0, 0)
GLib.idle_add( move_cursor, buffer, pointer_pos )
def _prompt_completion_request(self):
event = Event_Factory.create_event("get_active_view")
self.emit_to("source_views", event)
view = event.response
event = Event_Factory.create_event(
"request_completion",
view = view,
provider = self.context._provider
)
self.emit_to("completion", event)

View File

@@ -0,0 +1,51 @@
# Python imports
# Lib imports
import gi
gi.require_version('GtkSource', '4')
from gi.repository import GtkSource
# Application imports
from libs.event_factory import Event_Factory, Code_Event_Types
from .default import DefaultHandler
class JavaHandler(DefaultHandler):
"""Java-specific: overrides definition, handles classFileContents."""
def handle(self, method: str, response, controller):
match method:
case "textDocument/definition":
self._handle_definition(response, controller)
case "java/classFileContents":
self._handle_class_file_contents(response)
case _:
super().handle(method, response, controller)
def _handle_definition(self, response, controller):
if not response: return
uri = response[0]["uri"]
if "jdt://" in uri:
controller._lsp_java_class_file_contents(uri)
return
self._prompt_goto_request(uri, response[0]["range"])
def _handle_class_file_contents(self, text: str):
event = Event_Factory.create_event("get_active_view")
self.emit_to("source_views", event)
view = event.response
file = view.command.exec("new_file")
buffer = view.get_buffer()
itr = buffer.get_iter_at_mark(buffer.get_insert())
lm = GtkSource.LanguageManager.get_default()
language = lm.get_language("java")
file.ftype = "java"
buffer.set_language(language)
buffer.insert(itr, text, -1)

View File

@@ -0,0 +1,12 @@
# Python imports
# Lib imports
# Application imports
from .default import DefaultHandler
class PythonHandler(DefaultHandler):
"""Uses default handling, can override if Python needs special logic."""
...

View File

@@ -0,0 +1,52 @@
# Python imports
# Lib imports
# Application imports
from .base import BaseHandler
from .default import DefaultHandler
from .python import PythonHandler
from .java import JavaHandler
class ResponseRegistry:
def __init__(self):
self._instances: dict = {}
self._lang_handlers: dict = {
"default": DefaultHandler,
"python": PythonHandler,
"java": JavaHandler,
}
def set_event_hub(self, emit, emit_to, provider=None):
self.emit = emit
self.emit_to = emit_to
self._provider = provider
def _get_instance(self, handler_cls: type[BaseHandler]) -> BaseHandler:
if handler_cls in self._instances: return self._instances[handler_cls]
self._instances[handler_cls] = handler_cls()
return self._instances[handler_cls]
def register_handler(self, lang_id: str, handler_cls: type[BaseHandler]):
self._lang_handlers[lang_id] = handler_cls
def get_handler(self, lang_id: str = "", method: str = ""):
handler_cls = self._lang_handlers.get(
lang_id, self._lang_handlers.get("default", DefaultHandler)
)
if not handler_cls: return None
return self._get_instance(handler_cls)
def close_handler(self, lang_id: str):
if not lang_id in self._lang_handlers: return
handler_cls = self._lang_handlers[lang_id]
self._instances.pop(handler_cls, None)

View File

@@ -0,0 +1,8 @@
#!/bin/bash
# set -o xtrace ## To debug scripts
# set -o errexit ## To exit on error
# set -o errunset ## To exit if a variable is referenced but not set
CONTAINER="newton-lsp"

View File

@@ -0,0 +1,38 @@
#!/bin/bash
. CONFIG.sh
# set -o xtrace ## To debug scripts
# set -o errexit ## To exit on error
# set -o errunset ## To exit if a variable is referenced but not set
function main() {
SCRIPTPATH="$( cd "$(dirname "$0")" >/dev/null 2>&1 ; pwd -P )"
cd "${SCRIPTPATH}"
echo "Working Dir: " $(pwd)
ID=$(podman ps --filter "ancestor=localhost/${CONTAINER}:latest" --format "{{.ID}}")
if [ "${ID}" != "" ]; then
echo "Is up..."
exit 1
fi
CODE_HOST="${HOME}/Coding"
CODE_CONTAINER="${HOME}/Coding"
CONFIG_HOST="${HOME}/.config/lsps"
CONFIG_CONTAINER="${HOME}/.config/lsps"
# podman run -d -m 4G \
podman run -m 4G \
-p 9999:9999 \
-e HOME="${HOME}" \
-e MAVEN_OPTS="-Duser.home=${HOME}" \
-e JAVA_TOOL_OPTIONS="-Duser.home=${HOME}" \
-e JDTLS_CONFIG_PATH="${CONFIG_CONTAINER}/jdtls" \
-e JDTLS_DATA_PATH="${JDTLS_CONFIG_PATH}/data" \
-v "${CODE_HOST}":"${CODE_CONTAINER}" \
-v "${CONFIG_HOST}":"${CONFIG_CONTAINER}" \
"${CONTAINER}:latest"
}
main $@;

View File

@@ -0,0 +1,23 @@
#!/bin/bash
. CONFIG.sh
# set -o xtrace ## To debug scripts
# set -o errexit ## To exit on error
# set -o errunset ## To exit if a variable is referenced but not set
function main() {
SCRIPTPATH="$( cd "$(dirname "$0")" >/dev/null 2>&1 ; pwd -P )"
cd "${SCRIPTPATH}"
echo "Working Dir: " $(pwd)
ID=$(podman ps --filter "ancestor=localhost/${CONTAINER}:latest" --format "{{.ID}}")
if [ "${ID}" == "" ]; then
echo "Is not up..."
exit 1
fi
podman container stop "${ID}"
}
main $@;

View File

@@ -9,6 +9,7 @@ from gi.repository import GtkSource
from gi.repository import Gtk
from gi.repository import Gdk
from gi.repository import Gio
from gi.repository import GLib
# Application imports
from ..command_helpers import update_info_bar_if_focused
@@ -35,6 +36,10 @@ def execute(
update_info_bar_if_focused(view.command, view)
view.emit("focus-in-event", Gdk.Event())
buffer = view.get_buffer()
itr = buffer.get_iter_at_mark( buffer.get_insert() )
view.scroll_to_iter(itr, 0.2, False, 0, 0)
def scroll_to_insert_itr(view):
buffer = view.get_buffer()
itr = buffer.get_iter_at_mark( buffer.get_insert() )
view.scroll_to_iter(itr, 0.2, False, 0, 0)
GLib.idle_add(scroll_to_insert_itr, view)