From b30592351b7f331c0f8c70c8a1ea3ba03113d456 Mon Sep 17 00:00:00 2001
From: Nico Alt <nicoalt@posteo.org>
Date: Sun, 5 Apr 2020 13:00:01 +0000
Subject: [PATCH] Introduce connect function to socket listener

Similar to GTK's signals, listeners for events can now be added to
SocketListener.
---
 briar_wrapper/api.py                    | 11 +++---
 briar_wrapper/models/contacts.py        |  9 -----
 briar_wrapper/models/private_chat.py    |  8 -----
 briar_wrapper/models/socket_listener.py | 46 ++++++++++++++++++-------
 4 files changed, 40 insertions(+), 34 deletions(-)

diff --git a/briar_wrapper/api.py b/briar_wrapper/api.py
index cfc04d2..6df2ec8 100644
--- a/briar_wrapper/api.py
+++ b/briar_wrapper/api.py
@@ -23,8 +23,6 @@ class Api:
     def __init__(self, headless_jar):
         self._command = ["java", "-jar", headless_jar]
 
-        self.socket_listener = SocketListener(self)
-
     @staticmethod
     def has_account():
         return isfile(BRIAR_DB)
@@ -67,14 +65,17 @@ class Api:
                 sleep(0.1)
             except HTTPError as http_error:
                 if http_error.code == 404:
-                    self._load_auth_token()
-                    callback(True)
-                    return
+                    return self._on_successful_startup(callback)
             except URLError as url_error:
                 if not isinstance(url_error.reason, ConnectionRefusedError):
                     raise url_error
         callback(False)
 
+    def _on_successful_startup(self, callback):
+        self._load_auth_token()
+        self.socket_listener = SocketListener(self)
+        callback(True)
+
     def _login(self, password):
         if not self.is_running():
             raise Exception("Can't login; API not running")
diff --git a/briar_wrapper/models/contacts.py b/briar_wrapper/models/contacts.py
index adcd584..b256a8c 100644
--- a/briar_wrapper/models/contacts.py
+++ b/briar_wrapper/models/contacts.py
@@ -34,12 +34,3 @@ class Contacts(Model):
         url = urljoin(BASE_HTTP_URL, self.API_ENDPOINT + "add/" + "link/")
         request = _get(url, headers=self._headers).json()
         return request['link']
-
-    def watch_contacts(self, callback):
-        self._on_contact_added_callback = callback
-        self._api.socket_listener.watch("ContactAddedEvent",
-                                        self._on_contact_added)
-
-    # pylint: disable=unused-argument
-    def _on_contact_added(self, event):
-        self._on_contact_added_callback()
diff --git a/briar_wrapper/models/private_chat.py b/briar_wrapper/models/private_chat.py
index 1b5e4df..2067fbe 100644
--- a/briar_wrapper/models/private_chat.py
+++ b/briar_wrapper/models/private_chat.py
@@ -28,14 +28,6 @@ class PrivateChat(Model):
         request = _get(url, headers=self._headers)
         return request.json()
 
-    def watch_messages(self, callback):
-        self._on_message_received_callback = callback
-        self._api.socket_listener.watch("ConversationMessageReceivedEvent",
-                                        self._on_message_received)
-
-    def _on_message_received(self, event):
-        self._on_message_received_callback(event['data'])
-
     def send(self, message):
         url = urljoin(BASE_HTTP_URL,
                       self.API_ENDPOINT + "/%i" % self._contact_id)
diff --git a/briar_wrapper/models/socket_listener.py b/briar_wrapper/models/socket_listener.py
index 84e2408..0860e16 100644
--- a/briar_wrapper/models/socket_listener.py
+++ b/briar_wrapper/models/socket_listener.py
@@ -4,7 +4,7 @@
 
 import asyncio
 import json
-from threading import Thread
+from threading import Thread, Lock
 
 import websockets
 
@@ -12,33 +12,55 @@ from briar_wrapper.constants import WEBSOCKET_URL
 from briar_wrapper.model import Model
 
 
-class SocketListener(Model):  # pylint: disable=too-few-public-methods
+class SocketListener():  # pylint: disable=too-few-public-methods
 
-    def watch(self, event, callback):
+    def __init__(self, api):
+        self._api = api
+        self._signals = list()
+        self._signals_lock = Lock()
+        self._start_websocket_thread()
+
+    def connect(self, event, callback):
+        self._signals_lock.acquire()
+        # TODO: Signal ID should be stable after disconnects
+        signal_id = len(self._signals)
+        self._signals.append({
+            "event": event,
+            "callback": callback
+            })
+        self._signals_lock.release()
+        return signal_id
+
+    def _start_websocket_thread(self):
         websocket_thread = Thread(target=self._start_watch_loop,
-                                  args=(event, callback),
                                   daemon=True)
         websocket_thread.start()
 
-    def _start_watch_loop(self, event, callback):
+    def _start_watch_loop(self):
         loop = asyncio.new_event_loop()
         asyncio.set_event_loop(loop)
-        loop.create_task(self._start_websocket(event, callback))
+        loop.create_task(self._start_websocket())
         loop.run_forever()
         loop.close()
 
-    async def _start_websocket(self, event, callback):
+    async def _start_websocket(self):
         async with websockets.connect(WEBSOCKET_URL) as websocket:
             await websocket.send(self._api.auth_token)
-            await self._watch_messages(websocket, event, callback)
+            await self._watch_messages(websocket)
 
-    async def _watch_messages(self, websocket, event, callback):
+    async def _watch_messages(self, websocket):
         while not websocket.closed and not\
                 asyncio.get_event_loop().is_closed():
             message_json = await websocket.recv()
             message = json.loads(message_json)
-            if message['name'] == event:
-                callback(message)
+            self._call_signal_callbacks(message)
         if not asyncio.get_event_loop().is_closed():
             asyncio.get_event_loop().create_task(
-                self._watch_messages(websocket, event, callback))
+                self._watch_messages(websocket))
+
+    def _call_signal_callbacks(self, message):
+            self._signals_lock.acquire()
+            for signal in self._signals:
+                if signal["event"] == message['name']:
+                    signal["callback"](message)
+            self._signals_lock.release()
-- 
GitLab