298 lines
14 KiB
Python
298 lines
14 KiB
Python
from concurrent.futures import ThreadPoolExecutor
|
|
from lib import clore_partner
|
|
import asyncio
|
|
import random
|
|
import websockets
|
|
import json
|
|
|
|
from lib import config as config_module
|
|
config = config_module.config
|
|
|
|
from lib import logging as logging_lib
|
|
from lib import utils
|
|
|
|
from clore_hosting import utils as clore_utils
|
|
log = logging_lib.log
|
|
|
|
async def run_command_via_executor(command):
    """Run ``utils.run_command_v2(command)`` on a worker thread.

    Each call creates its own ThreadPoolExecutor instead of using the event
    loop's default executor, and shuts it down once the command returns.
    The command's return value is discarded.

    NOTE(review): if this coroutine is cancelled while the command is still
    running, ``shutdown(wait=True)`` runs synchronously and blocks the event
    loop until the command finishes.
    """
    event_loop = asyncio.get_running_loop()
    executor = ThreadPoolExecutor()
    try:
        await event_loop.run_in_executor(executor, utils.run_command_v2, command)
    finally:
        # Same as leaving a ``with ThreadPoolExecutor()`` block.
        executor.shutdown(wait=True)
|
|
|
|
def trim_container_log(string, max_size=None):
    """Return at most the last *max_size* characters of *string*.

    :param string: container log text.
    :param max_size: maximum number of characters to keep; defaults to
        ``config.max_container_log_size``.
    :returns: *string* unchanged when it already fits, otherwise its tail.
    """
    if max_size is None:
        max_size = config.max_container_log_size
    if len(string) <= max_size:
        return string
    # Slice from an explicit start index: ``string[-max_size:]`` would return
    # the whole string when max_size is 0, because -0 == 0.
    return string[len(string) - max_size:]
|
|
|
|
class WebSocketClient:
    """Client for the clore.ai websocket control channel.

    ``run()`` repeatedly connects to a randomly chosen peer from ``ws_peers``,
    sends a login message and processes server messages until the connection
    closes. ``stream_pull_logs()`` and ``stream_container_logs()`` push image
    pull logs and container logs over the same connection.
    """

    def __init__(self, log_message_broker, auth=None):
        """Create a disconnected client.

        :param log_message_broker: queue drained by ``stream_container_logs()``
            via ``empty()`` / ``await get()`` (presumably an ``asyncio.Queue``
            — confirm with the caller).
        :param auth: login key sent in the login message.
        """
        self.ws_peers = []
        self.connection = None
        self.connected = False
        self.authorized = False
        self.auth = auth
        self.xfs_state = None
        # "Invalid auth key" is logged only once until the next AUTHORIZED.
        self.log_auth_fail = True
        self.last_heartbeat = clore_utils.unix_timestamp()

        # Must be a list: get_containers() concatenates it with another list.
        self.containers = []
        self.containers_set = False

        self.pull_logs = {}
        self.pull_logs_last_fingerprints = {}

        # image -> most recent pull log, kept until a matching PROVEPULL arrives.
        self.to_stream = {}

        self.log_message_broker = log_message_broker
        self.allowed_log_container_names = []
        self.current_container_logs = {}

        self.last_bash_rnd = ''

        self.oc_enabled = False
        self.last_gpu_oc_specs = []
        self.last_set_oc = {}

        self.clore_partner_config = None
        self.forwarding_latency_measurment = None

        self.gpu_list = []
        self.is_hive = False

        # Strong references to fire-and-forget tasks; the event loop only keeps
        # weak references, so an unreferenced task may be garbage-collected.
        self._background_tasks = set()

    def set_gpu_list(self, gpu_list):
        """Set the GPU list included in the login message."""
        self.gpu_list = gpu_list

    def set_is_hive(self, is_hive):
        """Set the ``is_hive`` flag included in the login message."""
        self.is_hive = is_hive

    def get_last_heartbeat(self):
        """Return the timestamp of the last AUTHORIZED, KEEPALIVE or set_containers message."""
        return self.last_heartbeat

    def get_containers(self):
        """Return ``(containers_set, containers)``.

        When ``clore_partner.get_partner_container_config()`` returns a truthy
        value, it is appended to a copy of the server-provided container list.
        """
        partner_container_config = clore_partner.get_partner_container_config()
        if partner_container_config:
            return self.containers_set, self.containers + [partner_container_config]
        return self.containers_set, self.containers

    def get_oc(self):
        """Return ``(oc_enabled, last_gpu_oc_specs, last_set_oc)``."""
        return self.oc_enabled, self.last_gpu_oc_specs, self.last_set_oc

    def get_clore_partner_config(self):
        """Return the last partner config received from the server, or None."""
        return self.clore_partner_config

    async def set_forwarding_latency_measurment(self, forwarding_latency_measurment):
        """Send the measurement now and keep it so the next KEEPALIVE re-sends it once."""
        await self.send(json.dumps(
            {
                "forwarding_latency_measurment": forwarding_latency_measurment
            }
        ))
        self.forwarding_latency_measurment = forwarding_latency_measurment

    def set_ws_peers(self, ws_peers):
        """Keep only the keys of *ws_peers* that are valid websocket URLs."""
        self.ws_peers = [
            ws_peer for ws_peer in list(ws_peers.keys())
            if clore_utils.is_valid_websocket_url(ws_peer)
        ]

    def set_auth(self, auth, xfs_state):
        """Set the login key and xfs_state sent on the next login."""
        self.auth = auth
        self.xfs_state = xfs_state

    def set_pull_logs(self, pull_logs):
        """Replace the image -> pull status mapping read by stream_pull_logs()."""
        self.pull_logs = pull_logs

    async def close_websocket(self, timeout=5):
        """Close the connection, aborting the transport if the close handshake
        does not finish within *timeout* seconds."""
        if self.connection is None:
            return
        try:
            await asyncio.wait_for(self.connection.close(), timeout)
        except asyncio.TimeoutError:
            log.debug("close_websocket() | Closing timed out. Forcing close.")
            # Abort the underlying transport; a state check such as
            # ensure_open() would not close anything.
            transport = getattr(self.connection, "transport", None)
            if transport is not None:
                try:
                    transport.abort()
                except Exception as e:
                    log.debug(f"close_websocket() | Forced close failed: {e}")

    async def connect(self):
        """Connect to a random peer and send the login message.

        Does nothing when no peers are configured or ``auth`` is unset. On
        failure the client is marked disconnected and unauthorized and the
        pull-log fingerprints are cleared.
        """
        if not self.ws_peers or not self.auth:
            return
        random_ws_peer = random.choice(self.ws_peers)
        try:
            self.connection = await websockets.connect(random_ws_peer)
            self.connected = True
            log.debug(f"CLOREWS | Connected to {random_ws_peer} ✅")
            # The login message carries the auth key, so keep it out of logs.
            await self.send(json.dumps({
                "login": str(self.auth),
                "xfs_state": self.xfs_state,
                "type": "python",
                "clore_partner_support": True,
                "gpu_list": self.gpu_list,
                "is_hive": self.is_hive
            }), log_content=False)
        except Exception as e:
            log.debug(f"CLOREWS | Connection to {random_ws_peer} failed: {e} ❌")
            self.connected = False
            self.authorized = False
            self.pull_logs_last_fingerprints = {}

    async def send(self, message, *, log_content=True):
        """Send *message* (a str, or a dict that is JSON-encoded first).

        :param log_content: when False, the debug log line omits the message body.
        :returns: True on success; False when not connected or when sending
            fails (failures are logged at debug level, not raised).
        """
        if not (self.connection and self.connected):
            return False
        try:
            if isinstance(message, dict):
                message = json.dumps(message)
            await self.connection.send(message)
        except Exception as e:
            log.debug(f"CLOREWS | Send failed: {e}")
            return False
        if log_content:
            log.debug(f"CLOREWS | Message sent: {message}")
        else:
            log.debug("CLOREWS | Message sent: <redacted>")
        return True

    async def receive(self):
        """Process server messages until the connection closes."""
        while self.connected:
            try:
                message = await self.connection.recv()
                await self._handle_message(message)
            except websockets.exceptions.ConnectionClosed:
                log.debug("CLOREWS | Connection closed, attempting to reconnect...")
                self.connected = False
                self.authorized = False
                self.pull_logs_last_fingerprints = {}
                self.containers_set = False

    async def _handle_message(self, message):
        """Dispatch one message received from the server."""
        if message == "NOT_AUTHORIZED":
            # Always match this message (even when not logging) so repeated
            # rejections do not fall through to the JSON parser.
            if self.log_auth_fail:
                self.log_auth_fail = False
                log.error("🔑 Invalid auth key for clore.ai")
        elif message == "AUTHORIZED":
            self.log_auth_fail = True
            self.containers_set = False
            self.last_heartbeat = clore_utils.unix_timestamp()
            self.authorized = True
            log.success("🔑 Authorized with clore.ai")
            try:
                await self._resend_container_logs()
            except Exception as e:
                log.debug(f"CLOREWS | Re-sending container logs failed: {e}")
        elif message == "KEEPALIVE":
            self.last_heartbeat = clore_utils.unix_timestamp()
            try:
                if self.forwarding_latency_measurment:
                    await self.send(json.dumps(
                        {
                            "forwarding_latency_measurment": self.forwarding_latency_measurment
                        }
                    ))
                    self.forwarding_latency_measurment = None
            except Exception as e:
                log.debug(f"CLOREWS | Sending latency measurement failed: {e}")
        elif message in ("NEWER_LOGIN", "WAIT"):
            await self.close_websocket()
        elif isinstance(message, str) and message.startswith("PROVEPULL;"):
            self._handle_provepull(message)
        else:
            try:
                await self._handle_json_message(message)
            except Exception as e:
                log.error(f"CLOREWS | JSON | {e}")

    async def _resend_container_logs(self):
        """Send every cached container log again as a "set" message."""
        # Snapshot the names: sending awaits, and the cache may change meanwhile.
        for container_name in list(self.current_container_logs):
            container_log = self.current_container_logs.get(container_name)
            if container_log is None:
                continue
            await self.send(json.dumps({"container_log": container_log, "type": "set", "container_name": container_name}))

    def _handle_provepull(self, message):
        """Handle ``PROVEPULL;<image>;<md5>``.

        The queued pull log for <image> is dropped from ``to_stream`` when
        <md5> equals ``utils.hash_md5`` of that log.
        """
        parts = message.split(';')
        if len(parts) == 3 and parts[1] in self.to_stream:
            if utils.hash_md5(self.to_stream[parts[1]]) == parts[2]:
                del self.to_stream[parts[1]]

    async def _handle_json_message(self, message):
        """Parse a JSON command from the server and apply it.

        Non-object JSON payloads are ignored. Raises whatever ``json.loads``
        raises for invalid JSON.
        """
        parsed_json = json.loads(message)
        if not isinstance(parsed_json, dict):
            return
        if parsed_json.get("type") == "partner_config" and isinstance(parsed_json.get("partner_config"), dict):
            self.clore_partner_config = parsed_json["partner_config"]
            await self.send(json.dumps({"partner_config": parsed_json["partner_config"]}))
        elif parsed_json.get("type") == "set_containers" and isinstance(parsed_json.get("new_containers"), list):
            new_containers = parsed_json["new_containers"]
            self.last_heartbeat = clore_utils.unix_timestamp()
            await self.send(json.dumps({"containers": new_containers}))
            if len(new_containers) > 0:  # There should be at least one container
                self.containers_set = True
            self.containers = clore_partner.filter_partner_dummy_workload_container(new_containers)
        elif "allow_oc" in parsed_json:  # Enable OC
            self.oc_enabled = True
            await self.send(json.dumps({"allow_oc": True}))
        elif "gpu_oc_info" in parsed_json:
            self.last_gpu_oc_specs = parsed_json["gpu_oc_info"]
        elif "set_oc" in parsed_json:  # Set specific OC
            self.last_set_oc = parsed_json["set_oc"]
            await self.send(json.dumps({"current_oc": json.dumps(parsed_json["set_oc"], separators=(',', ':'))}))
        elif isinstance(parsed_json.get("bash_cmd"), str) and "bash_rnd" in parsed_json:
            bash_rnd = parsed_json["bash_rnd"]
            # Always acknowledge, but only run a command whose bash_rnd differs
            # from the last one executed.
            await self.send(json.dumps({"bash_rnd": bash_rnd}))
            if self.last_bash_rnd != bash_rnd:
                self.last_bash_rnd = bash_rnd
                self._spawn_background(run_command_via_executor(parsed_json["bash_cmd"]))

    def _spawn_background(self, coro):
        """Schedule *coro* as a task and hold a reference until it finishes."""
        task = asyncio.create_task(coro)
        self._background_tasks.add(task)
        task.add_done_callback(self._on_background_task_done)

    def _on_background_task_done(self, task):
        """Forget a finished background task and log its failure, if any."""
        self._background_tasks.discard(task)
        if not task.cancelled() and task.exception() is not None:
            log.error(f"CLOREWS | Background task failed: {task.exception()}")

    async def stream_pull_logs(self):
        """Queue changed image pull logs and send queued ones to the server.

        A pull log is queued in ``to_stream`` when its MD5 differs from the
        last fingerprint seen; it stays queued until a matching PROVEPULL
        arrives. A pass sends at most
        ``config.max_pull_logs_per_submit_run["instances"]`` logs, each cut to
        its last ``config.max_pull_log_size`` characters. A log is sent only
        while the characters already sent in this pass are at most
        ``config.max_pull_logs_per_submit_run["size"]``.

        :returns: True after a pass; None when not authorized.
        """
        if not self.authorized:
            return None

        for image_str, value in list(self.pull_logs.items()):
            if "log" not in value:
                continue
            current_hash = utils.hash_md5(value["log"])
            if self.pull_logs_last_fingerprints.get(image_str, '') != current_hash:
                self.pull_logs_last_fingerprints[image_str] = current_hash
                self.to_stream[image_str] = value["log"]  # This makes sure, that each time it will submit the most recent version

        ttl_submited_characters = 0
        # Snapshot: a PROVEPULL handled while send() is awaited deletes entries.
        for index, image in enumerate(list(self.to_stream)):
            try:
                limits = config.max_pull_logs_per_submit_run
                if index >= limits["instances"] or ttl_submited_characters > limits["size"]:
                    break  # both quantities only grow, so nothing later qualifies
                submit_log = self.to_stream.get(image)
                if submit_log is None:
                    continue
                to_submit_log = submit_log[-config.max_pull_log_size:]
                ttl_submited_characters += len(to_submit_log)
                await self.send({
                    "pull_log": to_submit_log,
                    "image": image
                })
            except Exception as e:
                log.error(e)
        return True

    async def stream_container_logs(self):
        """Drain the log broker and forward container log updates.

        A list item replaces ``allowed_log_container_names``; a
        ``"<container_name>|<data>"`` string is a log update for that
        container. Afterwards, cached logs of containers that are no longer
        allowed are dropped.
        """
        got_data = []
        while not self.log_message_broker.empty():
            got_data.append(await self.log_message_broker.get())
        if not got_data:
            return
        for data_sample in got_data:
            if isinstance(data_sample, list):
                self.allowed_log_container_names = data_sample
            elif isinstance(data_sample, str) and '|' in data_sample:
                container_name, data = data_sample.split('|', 1)
                await self._apply_container_log_update(container_name, data)
        # Deleting from a dict while iterating its live keys() view raises
        # RuntimeError, so collect the names first.
        stale_names = [
            name for name in self.current_container_logs
            if name not in self.allowed_log_container_names
        ]
        for name in stale_names:
            del self.current_container_logs[name]

    async def _apply_container_log_update(self, container_name, data):
        """Apply one log update for an allowed container.

        ``data == 'I'`` drops the cached log. Otherwise ``data[1:]`` is
        appended to (or starts) the cached log, and the same text is sent to
        the server as an "append" (or "set") message. The cache is trimmed to
        ``config.max_container_log_size``.
        """
        if container_name not in self.allowed_log_container_names:
            return
        if data == 'I':
            self.current_container_logs.pop(container_name, None)
            return
        # NOTE(review): the first character of data is dropped — presumably a
        # marker set by the broker's producer; confirm.
        log_txt = data[1:]
        if container_name in self.current_container_logs:
            self.current_container_logs[container_name] += log_txt
            await self.send(json.dumps({"container_log": log_txt, "type": "append", "container_name": container_name}))
        else:
            self.current_container_logs[container_name] = log_txt
            await self.send(json.dumps({"container_log": log_txt, "type": "set", "container_name": container_name}))
        if len(self.current_container_logs[container_name]) > config.max_container_log_size:
            self.current_container_logs[container_name] = trim_container_log(self.current_container_logs[container_name])

    async def ensure_connection(self):
        """Connect if not currently connected."""
        if not self.connected:
            await self.connect()

    async def run(self):
        """Connect, process messages until disconnect, wait 2 s, repeat forever."""
        while True:
            await self.connect()
            await self.receive()
            log.debug("CLOREWS | Waiting to reconnect WS")
            await asyncio.sleep(2)
|