# hosting/clore_hosting/ws_interface.py

from concurrent.futures import ThreadPoolExecutor
import asyncio
import random
import websockets
import json
from lib import config as config_module
config = config_module.config
from lib import logging as logging_lib
from lib import utils
from clore_hosting import utils as clore_utils
log = logging_lib.log
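# Run a blocking shell command on a worker thread so the asyncio event loop is
# not blocked while it executes (delegates to utils.run_command_v2).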
async def run_command_via_executor(command):
loop = asyncio.get_running_loop()
with ThreadPoolExecutor() as pool:
await loop.run_in_executor(pool, utils.run_command_v2, command)
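# Keep only the newest max_container_log_size characters of a container log.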
def trim_container_log(string):
if len(string) > config.max_container_log_size:
return string[-config.max_container_log_size:]
else:
return string
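# Maintains the websocket link to a clore.ai peer: logs in, answers keepalives,
# receives container / overclocking / bash commands and streams image pull logs
# and container logs back to the server.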
class WebSocketClient:
def __init__(self, log_message_broker, auth=None):
self.ws_peers = []
self.connection = None
self.connected = False
self.authorized = False
self.auth = auth
self.log_auth_fail = True
self.last_heartbeat = clore_utils.unix_timestamp()
self.containers={}
self.containers_set=False
self.pull_logs={}
self.pull_logs_last_fingerprints={}
self.to_stream={}
self.log_message_broker=log_message_broker
self.allowed_log_container_names = []
self.current_container_logs = {}
self.last_bash_rnd = ''
self.oc_enabled = False
self.last_gpu_oc_specs = []
self.last_set_oc = {}
def get_last_heartbeat(self):
return self.last_heartbeat
def get_containers(self):
return self.containers_set, self.containers
def get_oc(self):
return self.oc_enabled, self.last_gpu_oc_specs, self.last_set_oc
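    # Accept a dict of candidate peers and keep only the keys that are valid
    # websocket URLs.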
def set_ws_peers(self, ws_peers):
tmp_ws_peers=[]
for ws_peer in list(ws_peers.keys()):
if clore_utils.is_valid_websocket_url(ws_peer):
tmp_ws_peers.append(ws_peer)
self.ws_peers = tmp_ws_peers
def set_auth(self, auth):
self.auth=auth
def set_pull_logs(self, pull_logs):
self.pull_logs=pull_logs
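    # Close the websocket, falling back to a best-effort attempt if the graceful
    # close does not finish within `timeout` seconds.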
async def close_websocket(self, timeout=5):
try:
await asyncio.wait_for(self.connection.close(), timeout)
except asyncio.TimeoutError:
log.debug("close_websocket() | Closing timed out. Forcing close.")
try:
await self.connection.ensure_open() # Force close
except Exception as e:
pass
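    # Pick a random peer, open the websocket and immediately send the login
    # payload; authorization is only confirmed once receive() sees "AUTHORIZED".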
async def connect(self):
if len(self.ws_peers)>0 and self.auth:
random_ws_peer = random.choice(self.ws_peers)
try:
self.connection = await websockets.connect(random_ws_peer)
self.connected = True
log.debug(f"CLOREWS | Connected to {random_ws_peer}")
await self.send(json.dumps({
"login":str(self.auth),
"type":"python"
}))
except Exception as e:
log.debug(f"CLOREWS | Connection to {random_ws_peer} failed: {e}")
self.connected = False
self.authorized = False
self.pull_logs_last_fingerprints={}
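    # Send a message (dicts are JSON-encoded first). Returns True on success,
    # False when not connected or when sending raised.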
async def send(self, message):
try:
if self.connection and self.connected:
if type(message)==dict:
message=json.dumps(message)
await self.connection.send(message)
log.debug(f"CLOREWS | Message sent: {message}")
return True
else:
return False
except Exception as e:
return False
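    # Main receive loop: handles the plain-text control messages
    # (NOT_AUTHORIZED, AUTHORIZED, KEEPALIVE, NEWER_LOGIN/WAIT, PROVEPULL;...)
    # and the JSON commands (set_containers, allow_oc, gpu_oc_info, set_oc,
    # bash_cmd) until the connection closes.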
async def receive(self):
while self.connected:
try:
message = await self.connection.recv()
if message=="NOT_AUTHORIZED" and self.log_auth_fail:
self.log_auth_fail = False
log.error("🔑 Invalid auth key for clore.ai")
elif message=="AUTHORIZED":
self.log_auth_fail = True
self.containers_set = False
self.last_heartbeat = clore_utils.unix_timestamp()
self.authorized=True
log.success("🔑 Authorized with clore.ai")
try:
current_container_logs_keys = self.current_container_logs.keys()
for container_name in current_container_logs_keys:
await self.send(json.dumps({"container_log": self.current_container_logs[container_name], "type":"set", "container_name":container_name}))
except Exception as ei:
pass
elif message=="KEEPALIVE":
self.last_heartbeat = clore_utils.unix_timestamp()
elif message=="NEWER_LOGIN" or message=="WAIT":
await self.close_websocket()
elif message[:10]=="PROVEPULL;":
parts = message.split(';')
if len(parts)==3 and parts[1] in self.to_stream:
current_log = self.to_stream[parts[1]]
current_log_hash = utils.hash_md5(current_log)
if current_log_hash==parts[2]:
del self.to_stream[parts[1]]
else:
try:
parsed_json = json.loads(message)
if "type" in parsed_json and parsed_json["type"]=="set_containers" and "new_containers" in parsed_json and type(parsed_json["new_containers"])==list:
self.last_heartbeat = clore_utils.unix_timestamp()
container_str = json.dumps({"containers":parsed_json["new_containers"]})
await self.send(container_str)
if len(parsed_json["new_containers"]) > 0: # There should be at least one container
self.containers_set = True
self.containers=parsed_json["new_containers"]
#log.success(container_str)
elif "allow_oc" in parsed_json: # Enable OC
self.oc_enabled=True
await self.send(json.dumps({"allow_oc":True}))
elif "gpu_oc_info" in parsed_json:
self.last_gpu_oc_specs = parsed_json["gpu_oc_info"]
elif "set_oc" in parsed_json: # Set specific OC
self.last_set_oc=parsed_json["set_oc"]
back_oc_str = json.dumps({"current_oc":json.dumps(parsed_json["set_oc"], separators=(',',':'))})
await self.send(back_oc_str)
elif "bash_cmd" in parsed_json and type(parsed_json["bash_cmd"])==str and "bash_rnd" in parsed_json:
await self.send(json.dumps({"bash_rnd":parsed_json["bash_rnd"]}))
if self.last_bash_rnd!=parsed_json["bash_rnd"]:
self.last_bash_rnd=parsed_json["bash_rnd"]
asyncio.create_task(run_command_via_executor(parsed_json["bash_cmd"]))
except Exception as e:
log.error(f"CLOREWS | JSON | {e}")
#log.success(f"Message received: {message}")
# Handle received message
except websockets.exceptions.ConnectionClosed:
log.debug("CLOREWS | Connection closed, attempting to reconnect...")
self.connected = False
self.authorized = False
self.pull_logs_last_fingerprints={}
self.containers_set = False
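    # Submit image pull logs whose content changed since the last run, capped by
    # config.max_pull_logs_per_submit_run; an entry stays queued in to_stream
    # until the server acknowledges it via a matching PROVEPULL hash.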
async def stream_pull_logs(self):
if self.authorized:
#self.pull_logs_last_fingerprints
for image_str in self.pull_logs.keys():
value = self.pull_logs[image_str]
last_hash = self.pull_logs_last_fingerprints[image_str] if image_str in self.pull_logs_last_fingerprints.keys() else ''
if "log" in value:
current_hash = utils.hash_md5(value["log"])
if last_hash != current_hash:
self.pull_logs_last_fingerprints[image_str]=current_hash
                        self.to_stream[image_str] = value["log"] # Always submit the most recent version of the log
ttl_submited_characters=0
for index, image in enumerate(self.to_stream.keys()):
try:
if index < config.max_pull_logs_per_submit_run["instances"] and ttl_submited_characters <= config.max_pull_logs_per_submit_run["size"]:
submit_log = self.to_stream[image]
to_submit_log = submit_log[-config.max_pull_log_size:]
ttl_submited_characters+=len(to_submit_log)
await self.send({
"pull_log":to_submit_log,
"image":image
})
except Exception as e:
log.error(e)
return True
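    # Drain the log message broker queue. A list item replaces the set of
    # container names allowed to stream; a string item is expected to be
    # "<container_name>|<payload>", where a payload of "I" clears the cached log
    # and any other payload has its first (marker) character stripped and the
    # rest appended to the cached log and forwarded to the server.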
async def stream_container_logs(self):
got_data=[]
while not self.log_message_broker.empty():
got_data.append(await self.log_message_broker.get())
#print("GOT DATA", got_data)
if len(got_data) > 0:
for data_sample in got_data:
if type(data_sample)==list:
self.allowed_log_container_names=data_sample
elif type(data_sample)==str and '|' in data_sample:
container_name, data = data_sample.split('|',1)
if container_name in self.allowed_log_container_names:
log_container_names = self.current_container_logs.keys()
if data=='I':
if container_name in log_container_names:
del self.current_container_logs[container_name]
else:
log_txt = data[1:]
if container_name in log_container_names:
self.current_container_logs[container_name]+=log_txt
await self.send(json.dumps({"container_log":log_txt, "type":"append", "container_name":container_name}))
else:
self.current_container_logs[container_name]=log_txt
await self.send(json.dumps({"container_log":log_txt, "type":"set", "container_name":container_name}))
if len(self.current_container_logs[container_name]) > config.max_container_log_size:
self.current_container_logs[container_name]=trim_container_log(self.current_container_logs[container_name])
            container_log_in_cache_names = list(self.current_container_logs.keys()) # snapshot, because keys may be deleted below
            for container_in_cache_name in container_log_in_cache_names:
                if container_in_cache_name not in self.allowed_log_container_names:
                    del self.current_container_logs[container_in_cache_name]
async def ensure_connection(self):
if not self.connected:
await self.connect()
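    # Connect and receive forever, waiting two seconds between reconnect attempts.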
async def run(self):
while True:
await self.connect()
receive_task = asyncio.create_task(self.receive())
await receive_task
log.debug("CLOREWS | Waiting to reconnect WS")
await asyncio.sleep(2)
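
# Illustrative wiring only (an assumption -- the real entry point lives elsewhere
# in the hosting package and may differ):
#
#   import asyncio
#
#   async def main():
#       log_queue = asyncio.Queue()                        # feeds stream_container_logs()
#       client = WebSocketClient(log_queue, auth="<auth-key>")
#       client.set_ws_peers({"wss://example.peer/ws": {}})  # hypothetical peer URL
#       await client.run()
#
#   asyncio.run(main())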