"""Websocket client that keeps this host connected to clore.ai.

The client logs in with an auth key, reacts to server messages (container
lists, overclocking settings, bash commands, keep-alives) and streams
image-pull logs and container logs back to the server.
"""
from concurrent.futures import ThreadPoolExecutor
import asyncio
import random
import websockets
import json

from lib import config as config_module
config = config_module.config
from lib import logging as logging_lib
from lib import utils
from clore_hosting import utils as clore_utils

log = logging_lib.log


async def run_command_via_executor(command):
    """Run ``utils.run_command_v2(command)`` on a worker thread and await it.

    A dedicated thread pool is created per call, so the command does not run
    on the event loop's default executor.

    Returns:
        Whatever ``utils.run_command_v2`` returns.
    """
    loop = asyncio.get_running_loop()
    with ThreadPoolExecutor() as pool:
        return await loop.run_in_executor(pool, utils.run_command_v2, command)


def trim_container_log(string, max_size=None):
    """Return at most the last ``max_size`` characters of ``string``.

    Args:
        string: log text to trim.
        max_size: character limit; defaults to ``config.max_container_log_size``.

    Returns:
        ``string`` unchanged when it fits, otherwise its tail. A limit of 0
        yields an empty string.
    """
    if max_size is None:
        max_size = config.max_container_log_size
    if len(string) > max_size:
        # Explicit start index: ``string[-0:]`` would return the whole string.
        return string[len(string) - max_size:]
    return string


class WebSocketClient:
    """Maintain a websocket session with a clore.ai peer.

    Args:
        log_message_broker: object exposing ``empty()`` and an awaitable
            ``get()`` that delivers container-log updates (presumably an
            ``asyncio.Queue`` -- confirm with the caller).
        auth: auth key sent in the ``login`` message; ``connect()`` does
            nothing until it is set.
    """

    def __init__(self, log_message_broker, auth=None):
        self.ws_peers = []  # peer URLs that passed is_valid_websocket_url()
        self.connection = None
        self.connected = False
        self.authorized = False
        self.auth = auth
        # Log an invalid-key error only once until the next AUTHORIZED.
        self.log_auth_fail = True
        self.last_heartbeat = clore_utils.unix_timestamp()
        self.containers = {}
        self.containers_set = False

        self.pull_logs = {}  # image -> mapping that may contain a "log" entry
        self.pull_logs_last_fingerprints = {}  # image -> hash of last queued log

        self.to_stream = {}  # image -> pull log not yet confirmed via PROVEPULL

        self.log_message_broker = log_message_broker
        self.allowed_log_container_names = []
        self.current_container_logs = {}  # container name -> cached log text

        self.last_bash_rnd = ''  # bash_rnd of the most recently started command

        self.oc_enabled = False
        self.last_gpu_oc_specs = []
        self.last_set_oc = {}

        # Strong references to fire-and-forget tasks: the event loop keeps only
        # weak references, so an unreferenced task may vanish mid-execution.
        self._background_tasks = set()

    def get_last_heartbeat(self):
        """Return the timestamp of the last AUTHORIZED/KEEPALIVE/set_containers message."""
        return self.last_heartbeat

    def get_containers(self):
        """Return ``(containers_set, containers)``."""
        return self.containers_set, self.containers

    def get_oc(self):
        """Return ``(oc_enabled, last_gpu_oc_specs, last_set_oc)``."""
        return self.oc_enabled, self.last_gpu_oc_specs, self.last_set_oc

    def set_ws_peers(self, ws_peers):
        """Keep the keys of ``ws_peers`` that are valid websocket URLs as the peer list."""
        self.ws_peers = [
            peer for peer in ws_peers.keys()
            if clore_utils.is_valid_websocket_url(peer)
        ]

    def set_auth(self, auth):
        """Set the auth key used by the next ``connect()``."""
        self.auth = auth

    def set_pull_logs(self, pull_logs):
        """Replace the image -> pull-progress mapping read by ``stream_pull_logs()``."""
        self.pull_logs = pull_logs

    async def close_websocket(self, timeout=5):
        """Close the current connection, aborting it if closing exceeds ``timeout`` seconds.

        Does nothing when no connection has been opened yet.
        """
        if self.connection is None:
            return
        try:
            await asyncio.wait_for(self.connection.close(), timeout)
        except asyncio.TimeoutError:
            log.debug("close_websocket() | Closing timed out. Forcing close.")
            # Abort the underlying asyncio transport, dropping the connection
            # without waiting for the closing handshake. getattr keeps this
            # best-effort in case the connection object lacks ``transport``.
            transport = getattr(self.connection, "transport", None)
            if transport is not None:
                try:
                    transport.abort()
                except Exception as e:
                    log.debug(f"close_websocket() | Forced close failed: {e}")

    async def connect(self):
        """Connect to a randomly chosen peer and send the login message.

        Does nothing unless at least one peer and an auth key are set. On
        failure the session flags are reset so ``run()`` can retry.
        """
        if not self.ws_peers or not self.auth:
            return
        random_ws_peer = random.choice(self.ws_peers)
        try:
            self.connection = await websockets.connect(random_ws_peer)
            self.connected = True
            log.debug(f"CLOREWS | Connected to {random_ws_peer} ✅")
            await self.send(json.dumps({
                "login": str(self.auth),
                "type": "python"
            }))
        except Exception as e:
            log.debug(f"CLOREWS | Connection to {random_ws_peer} failed: {e} ❌")
            self.connected = False
            self.authorized = False
            self.pull_logs_last_fingerprints = {}

    async def send(self, message):
        """Send ``message`` (a str, or a dict serialized to JSON) if connected.

        Returns:
            True when the message was handed to the websocket; False when not
            connected or when serialization/sending failed. Never raises.
        """
        if not (self.connection and self.connected):
            return False
        try:
            if isinstance(message, dict):
                message = json.dumps(message)
            await self.connection.send(message)
            # NOTE(review): this also logs the login message, which contains
            # the auth key.
            log.debug(f"CLOREWS | Message sent: {message}")
            return True
        except Exception as e:
            log.debug(f"CLOREWS | Sending failed: {e}")
            return False

    async def receive(self):
        """Read and dispatch server messages until the connection is closed."""
        while self.connected:
            try:
                message = await self.connection.recv()
                if message == "NOT_AUTHORIZED":
                    # Handled here even when already logged, so repeated
                    # rejections never fall through to JSON parsing.
                    if self.log_auth_fail:
                        self.log_auth_fail = False
                        log.error("🔑 Invalid auth key for clore.ai")
                elif message == "AUTHORIZED":
                    self.log_auth_fail = True
                    self.containers_set = False
                    self.last_heartbeat = clore_utils.unix_timestamp()
                    self.authorized = True
                    log.success("🔑 Authorized with clore.ai")
                    await self._resend_container_logs()
                elif message == "KEEPALIVE":
                    self.last_heartbeat = clore_utils.unix_timestamp()
                elif message in ("NEWER_LOGIN", "WAIT"):
                    await self.close_websocket()
                elif message[:10] == "PROVEPULL;":
                    self._confirm_pull_log(message)
                else:
                    await self._handle_json_message(message)
            except websockets.exceptions.ConnectionClosed:
                log.debug("CLOREWS | Connection closed, attempting to reconnect...")
                self.connected = False
                self.authorized = False
                self.pull_logs_last_fingerprints = {}
                self.containers_set = False

    async def _resend_container_logs(self):
        """Send every cached container log as a full ("set") update."""
        # Snapshot: stream_container_logs() may modify the cache while send() awaits.
        for container_name, container_log in list(self.current_container_logs.items()):
            await self.send(json.dumps({"container_log": container_log, "type": "set", "container_name": container_name}))

    def _confirm_pull_log(self, message):
        """Handle ``PROVEPULL;<image>;<hash>``: drop the queued log whose ``utils.hash_md5`` matches."""
        parts = message.split(';')
        if len(parts) == 3 and parts[1] in self.to_stream:
            if utils.hash_md5(self.to_stream[parts[1]]) == parts[2]:
                del self.to_stream[parts[1]]

    async def _handle_json_message(self, message):
        """Parse ``message`` as JSON and act on the command it carries.

        Errors (invalid JSON, unexpected payload shape) are logged, not raised.
        """
        try:
            parsed_json = json.loads(message)
            if (
                "type" in parsed_json
                and parsed_json["type"] == "set_containers"
                and "new_containers" in parsed_json
                and isinstance(parsed_json["new_containers"], list)
            ):
                self.last_heartbeat = clore_utils.unix_timestamp()
                new_containers = parsed_json["new_containers"]
                # Echo the received container list back to the server.
                await self.send(json.dumps({"containers": new_containers}))
                if len(new_containers) > 0:  # There should be at least one container
                    self.containers_set = True
                    self.containers = new_containers
            elif "allow_oc" in parsed_json:  # Enable OC
                self.oc_enabled = True
                await self.send(json.dumps({"allow_oc": True}))
            elif "gpu_oc_info" in parsed_json:
                self.last_gpu_oc_specs = parsed_json["gpu_oc_info"]
            elif "set_oc" in parsed_json:  # Set specific OC
                self.last_set_oc = parsed_json["set_oc"]
                # Reply with the stored settings as a compact JSON string.
                back_oc_str = json.dumps({"current_oc": json.dumps(parsed_json["set_oc"], separators=(',', ':'))})
                await self.send(back_oc_str)
            elif (
                "bash_cmd" in parsed_json
                and isinstance(parsed_json["bash_cmd"], str)
                and "bash_rnd" in parsed_json
            ):
                # Always acknowledge; skip the command when its bash_rnd equals
                # the previous one (a repeated delivery).
                await self.send(json.dumps({"bash_rnd": parsed_json["bash_rnd"]}))
                if self.last_bash_rnd != parsed_json["bash_rnd"]:
                    self.last_bash_rnd = parsed_json["bash_rnd"]
                    # NOTE(review): runs a command string supplied by the server.
                    self._spawn_background(run_command_via_executor(parsed_json["bash_cmd"]))
        except Exception as e:
            log.error(f"CLOREWS | JSON | {e}")

    def _spawn_background(self, coro):
        """Schedule ``coro`` as a task and hold a reference to it until it finishes."""
        task = asyncio.create_task(coro)
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)
        return task

    async def stream_pull_logs(self):
        """Queue changed image-pull logs and send a bounded batch of them.

        Only acts while authorized. A log is queued in ``to_stream`` whenever
        its ``utils.hash_md5`` differs from the last queued one, and it stays
        queued until a matching PROVEPULL arrives. Each call sends at most
        ``config.max_pull_logs_per_submit_run["instances"]`` logs, stops once
        the running character count exceeds ``["size"]``, and sends only the
        last ``config.max_pull_log_size`` characters of each log.

        Returns:
            Always True.
        """
        if self.authorized:
            for image_str, value in self.pull_logs.items():
                if "log" not in value:
                    continue
                current_hash = utils.hash_md5(value["log"])
                if self.pull_logs_last_fingerprints.get(image_str, '') != current_hash:
                    self.pull_logs_last_fingerprints[image_str] = current_hash
                    # Keep the most recent version queued for submission.
                    self.to_stream[image_str] = value["log"]

            submitted_characters = 0
            # Iterate a snapshot: receive() may delete confirmed entries while
            # send() awaits, which would break iteration over the live dict.
            for index, image in enumerate(list(self.to_stream)):
                try:
                    if (index >= config.max_pull_logs_per_submit_run["instances"]
                            or submitted_characters > config.max_pull_logs_per_submit_run["size"]):
                        break
                    submit_log = self.to_stream.get(image)
                    if submit_log is None:  # confirmed in the meantime
                        continue
                    to_submit_log = submit_log[-config.max_pull_log_size:]
                    submitted_characters += len(to_submit_log)
                    await self.send({
                        "pull_log": to_submit_log,
                        "image": image
                    })
                except Exception as e:
                    log.error(e)
        return True

    async def stream_container_logs(self):
        """Drain ``log_message_broker`` and forward container-log updates.

        Broker items are processed in order:
          * a list replaces ``allowed_log_container_names``;
          * a string ``"<container>|<data>"`` for an allowed container either
            drops its cached log (``data == 'I'``) or, after discarding the
            first character of ``data``, appends the rest to the cache and
            sends it as an "append" (already cached) or "set" (new) update.
        Cached logs are trimmed to ``config.max_container_log_size``, and
        entries for containers that are no longer allowed are removed.
        """
        got_data = []
        while not self.log_message_broker.empty():
            got_data.append(await self.log_message_broker.get())
        if not got_data:
            return

        for data_sample in got_data:
            if isinstance(data_sample, list):
                self.allowed_log_container_names = data_sample
            elif isinstance(data_sample, str) and '|' in data_sample:
                container_name, data = data_sample.split('|', 1)
                if container_name not in self.allowed_log_container_names:
                    continue
                if data == 'I':
                    self.current_container_logs.pop(container_name, None)
                    continue
                log_txt = data[1:]
                if container_name in self.current_container_logs:
                    self.current_container_logs[container_name] = trim_container_log(
                        self.current_container_logs[container_name] + log_txt
                    )
                    await self.send(json.dumps({"container_log": log_txt, "type": "append", "container_name": container_name}))
                else:
                    self.current_container_logs[container_name] = trim_container_log(log_txt)
                    await self.send(json.dumps({"container_log": log_txt, "type": "set", "container_name": container_name}))

        # Iterate over a copy: deleting while iterating the live keys view
        # raises "dictionary changed size during iteration".
        for cached_name in list(self.current_container_logs):
            if cached_name not in self.allowed_log_container_names:
                del self.current_container_logs[cached_name]

    async def ensure_connection(self):
        """Connect if not currently connected."""
        if not self.connected:
            await self.connect()

    async def run(self):
        """Connect, process messages until disconnected, wait 2 seconds, repeat forever."""
        while True:
            await self.connect()
            try:
                await self.receive()
            except asyncio.CancelledError:
                raise
            except Exception as e:
                # Without this guard an unexpected error would escape run()
                # and stop the reconnect loop for good.
                log.error(f"CLOREWS | Receive loop failed: {e}")
                self.connected = False
                self.authorized = False
                self.pull_logs_last_fingerprints = {}
                self.containers_set = False
                try:
                    await self.close_websocket()
                except Exception as close_error:
                    log.debug(f"CLOREWS | Closing failed: {close_error}")
            log.debug("CLOREWS | Waiting to reconnect WS")
            await asyncio.sleep(2)