298 lines
		
	
	
		
			14 KiB
		
	
	
	
		
			Python
		
	
	
	
			
		
		
	
	
			298 lines
		
	
	
		
			14 KiB
		
	
	
	
		
			Python
		
	
	
	
| from concurrent.futures import ThreadPoolExecutor
 | |
| from lib import clore_partner
 | |
| import asyncio
 | |
| import random
 | |
| import websockets
 | |
| import json
 | |
| 
 | |
| from lib import config as config_module
 | |
| config = config_module.config
 | |
| 
 | |
| from lib import logging as logging_lib
 | |
| from lib import utils
 | |
| 
 | |
| from clore_hosting import utils as clore_utils
 | |
| log = logging_lib.log
 | |
| 
 | |
async def run_command_via_executor(command):
    """Run ``utils.run_command_v2(command)`` on a worker thread and await it.

    A ThreadPoolExecutor is created for this single call and shut down when
    the ``with`` block exits. The command's return value is discarded.
    """
    with ThreadPoolExecutor() as executor:
        event_loop = asyncio.get_running_loop()
        await event_loop.run_in_executor(executor, utils.run_command_v2, command)
 | |
| 
 | |
def trim_container_log(string, max_size=None):
    """Return at most the last ``max_size`` characters of ``string``.

    Args:
        string: The log text to trim.
        max_size: Maximum length to keep. When ``None`` (the default) the
            limit is read from ``config.max_container_log_size``.

    Returns:
        ``string`` unchanged when it already fits, otherwise its tail. A
        limit of zero or less yields an empty value of the same type.
    """
    if max_size is None:
        max_size = config.max_container_log_size
    if max_size <= 0:
        # ``string[-0:]`` would return the whole string, so handle it here.
        return string[:0]
    if len(string) > max_size:
        return string[-max_size:]
    return string
 | |
| 
 | |
class WebSocketClient:
    """Websocket client that logs in to a clore.ai peer and tracks its state.

    The instance keeps the connection/authorization flags, the container list
    and overclocking settings received from the server, and the buffers used
    to stream image-pull logs and container logs back to it.
    """

    def __init__(self, log_message_broker, auth=None):
        """Create a disconnected client.

        Args:
            log_message_broker: Queue-like object exposing ``empty()`` and an
                awaitable ``get()``; drained by ``stream_container_logs``.
            auth: Value sent (stringified) as ``login`` when connecting.
        """
        self.ws_peers = []
        self.connection = None
        self.connected = False
        self.authorized = False
        self.auth = auth
        self.xfs_state = None
        self.log_auth_fail = True
        self.last_heartbeat = clore_utils.unix_timestamp()
        # Kept as a list so get_containers() can append the partner container.
        self.containers = []
        self.containers_set = False

        self.pull_logs = {}
        self.pull_logs_last_fingerprints = {}

        # image -> latest pull log still waiting for a PROVEPULL confirmation.
        self.to_stream = {}

        self.log_message_broker = log_message_broker
        self.allowed_log_container_names = []
        self.current_container_logs = {}

        self.last_bash_rnd = ''

        self.oc_enabled = False
        self.last_gpu_oc_specs = []
        self.last_set_oc = {}

        self.clore_partner_config = None
        self.forwarding_latency_measurment = None

        self.gpu_list = []
        self.is_hive = False

        # Strong references to fire-and-forget tasks; the event loop itself
        # only keeps weak references to tasks.
        self._background_tasks = set()

    def set_gpu_list(self, gpu_list):
        """Store the GPU list sent with the next login message."""
        self.gpu_list = gpu_list

    def set_is_hive(self, is_hive):
        """Store the ``is_hive`` flag sent with the next login message."""
        self.is_hive = is_hive

    def get_last_heartbeat(self):
        """Return the timestamp of the last heartbeat-refreshing message."""
        return self.last_heartbeat

    def get_containers(self):
        """Return ``(containers_set, containers)``.

        When ``clore_partner`` supplies a partner container config, it is
        appended to a copy of the stored container list.
        """
        partner_container_config = clore_partner.get_partner_container_config()
        containers = self.containers
        if partner_container_config:
            containers = list(containers) + [partner_container_config]
        return self.containers_set, containers

    def get_oc(self):
        """Return ``(oc_enabled, last_gpu_oc_specs, last_set_oc)``."""
        return self.oc_enabled, self.last_gpu_oc_specs, self.last_set_oc

    def get_clore_partner_config(self):
        """Return the last partner config received, or ``None``."""
        return self.clore_partner_config

    async def set_forwarding_latency_measurment(self, forwarding_latency_measurment):
        """Send the measurement now and keep it for the next KEEPALIVE.

        ``receive`` sends the stored value once more on the next KEEPALIVE
        and then clears it.
        """
        await self.send(json.dumps(
            {
                "forwarding_latency_measurment": forwarding_latency_measurment
            }
        ))
        self.forwarding_latency_measurment = forwarding_latency_measurment

    def set_ws_peers(self, ws_peers):
        """Replace the peer list with the valid websocket URLs among the keys of ``ws_peers``."""
        self.ws_peers = [
            ws_peer for ws_peer in ws_peers
            if clore_utils.is_valid_websocket_url(ws_peer)
        ]

    def set_auth(self, auth, xfs_state):
        """Store the credentials and ``xfs_state`` used by the next login."""
        self.auth = auth
        self.xfs_state = xfs_state

    def set_pull_logs(self, pull_logs):
        """Replace the image -> pull-log mapping read by ``stream_pull_logs``."""
        self.pull_logs = pull_logs

    async def close_websocket(self, timeout=5):
        """Close the current connection, aborting it if closing times out.

        Args:
            timeout: Seconds to wait for the closing handshake.
        """
        connection = self.connection
        if connection is None:
            return
        try:
            await asyncio.wait_for(connection.close(), timeout)
        except asyncio.TimeoutError:
            log.debug("close_websocket() | Closing timed out. Forcing close.")
            # The transport attribute depends on the websockets version in
            # use, so look it up defensively instead of assuming it exists.
            transport = getattr(connection, "transport", None)
            if transport is not None:
                try:
                    transport.abort()
                except Exception as e:
                    log.debug(f"close_websocket() | Forced close failed: {e}")

    async def connect(self):
        """Connect to a randomly chosen peer and send the login message.

        Does nothing when there are no peers or no auth value. On failure the
        connection flags and pull-log fingerprints are reset.
        """
        if not self.ws_peers or not self.auth:
            return
        random_ws_peer = random.choice(self.ws_peers)
        try:
            self.connection = await websockets.connect(random_ws_peer)
            self.connected = True
            log.debug(f"CLOREWS | Connected to {random_ws_peer} ✅")
            await self.send(json.dumps({
                "login": str(self.auth),
                "xfs_state": self.xfs_state,
                "type": "python",
                "clore_partner_support": True,
                "gpu_list": self.gpu_list,
                "is_hive": self.is_hive
            }))
        except Exception as e:
            log.debug(f"CLOREWS | Connection to {random_ws_peer} failed: {e} ❌")
            self.connected = False
            self.authorized = False
            self.pull_logs_last_fingerprints = {}

    async def send(self, message):
        """Send ``message`` over the open connection.

        Args:
            message: A string, or a dict that is JSON-encoded before sending.

        Returns:
            ``True`` when the message was sent, ``False`` when there is no
            open connection or sending failed.
        """
        try:
            if not (self.connection and self.connected):
                return False
            if isinstance(message, dict):
                message = json.dumps(message)
            await self.connection.send(message)
            log.debug(f"CLOREWS | Message sent: {message}")
            return True
        except Exception as e:
            log.debug(f"CLOREWS | Failed to send message: {e}")
            return False

    async def receive(self):
        """Read and handle messages until the connection is closed."""
        while self.connected:
            try:
                message = await self.connection.recv()
                await self._handle_message(message)
            except websockets.exceptions.ConnectionClosed:
                log.debug("CLOREWS | Connection closed, attempting to reconnect...")
                self.connected = False
                self.authorized = False
                self.pull_logs_last_fingerprints = {}
                self.containers_set = False

    async def _handle_message(self, message):
        """Dispatch one received message: control keywords first, JSON otherwise."""
        if message == "NOT_AUTHORIZED":
            # Log the failure once; repeats stay silent until the next
            # successful authorization re-arms the flag.
            if self.log_auth_fail:
                self.log_auth_fail = False
                log.error("🔑 Invalid auth key for clore.ai")
        elif message == "AUTHORIZED":
            await self._handle_authorized()
        elif message == "KEEPALIVE":
            await self._handle_keepalive()
        elif message == "NEWER_LOGIN" or message == "WAIT":
            await self.close_websocket()
        elif isinstance(message, str) and message.startswith("PROVEPULL;"):
            self._handle_provepull(message)
        else:
            await self._handle_json_message(message)

    async def _handle_authorized(self):
        """Mark the session authorized and re-send every cached container log."""
        self.log_auth_fail = True
        self.containers_set = False
        self.last_heartbeat = clore_utils.unix_timestamp()
        self.authorized = True
        log.success("🔑 Authorized with clore.ai")
        try:
            # Iterate a snapshot of the names: the cache can change while
            # this coroutine is suspended inside send().
            for container_name in list(self.current_container_logs):
                container_log = self.current_container_logs.get(container_name)
                if container_log is None:
                    continue
                await self.send(json.dumps({"container_log": container_log, "type":"set", "container_name":container_name}))
        except Exception as e:
            log.debug(f"CLOREWS | Failed to re-send container logs: {e}")

    async def _handle_keepalive(self):
        """Refresh the heartbeat and flush a pending latency measurement."""
        self.last_heartbeat = clore_utils.unix_timestamp()
        try:
            if self.forwarding_latency_measurment:
                await self.send(json.dumps(
                    {
                        "forwarding_latency_measurment": self.forwarding_latency_measurment
                    }
                ))
                self.forwarding_latency_measurment = None
        except Exception as e:
            log.debug(f"CLOREWS | Failed to send latency measurement: {e}")

    def _handle_provepull(self, message):
        """Drop a queued pull log once the server echoes back its MD5 hash."""
        parts = message.split(';')
        if len(parts) == 3 and parts[1] in self.to_stream:
            current_log_hash = utils.hash_md5(self.to_stream[parts[1]])
            if current_log_hash == parts[2]:
                del self.to_stream[parts[1]]

    async def _handle_json_message(self, message):
        """Decode a JSON message and apply the first matching command."""
        try:
            parsed_json = json.loads(message)
            if not isinstance(parsed_json, dict):
                # The key checks below are only meaningful for JSON objects.
                log.error("CLOREWS | JSON | unexpected non-object message")
                return
            message_type = parsed_json.get("type")
            if message_type == "partner_config" and isinstance(parsed_json.get("partner_config"), dict):
                self.clore_partner_config = parsed_json["partner_config"]
                await self.send(json.dumps({"partner_config":parsed_json["partner_config"]}))
            elif message_type == "set_containers" and isinstance(parsed_json.get("new_containers"), list):
                new_containers = parsed_json["new_containers"]
                self.last_heartbeat = clore_utils.unix_timestamp()
                await self.send(json.dumps({"containers":new_containers}))
                if len(new_containers) > 0: # There should be at least one container
                    self.containers_set = True
                    self.containers = clore_partner.filter_partner_dummy_workload_container(new_containers)
            elif "allow_oc" in parsed_json: # Enable OC
                self.oc_enabled = True
                await self.send(json.dumps({"allow_oc":True}))
            elif "gpu_oc_info" in parsed_json:
                self.last_gpu_oc_specs = parsed_json["gpu_oc_info"]
            elif "set_oc" in parsed_json: # Set specific OC
                self.last_set_oc = parsed_json["set_oc"]
                back_oc_str = json.dumps({"current_oc":json.dumps(parsed_json["set_oc"], separators=(',',':'))})
                await self.send(back_oc_str)
            elif isinstance(parsed_json.get("bash_cmd"), str) and "bash_rnd" in parsed_json:
                await self.send(json.dumps({"bash_rnd":parsed_json["bash_rnd"]}))
                # bash_rnd de-duplicates consecutive deliveries of a command.
                if self.last_bash_rnd != parsed_json["bash_rnd"]:
                    self.last_bash_rnd = parsed_json["bash_rnd"]
                    # NOTE(review): this runs a command string supplied by the
                    # remote peer; confirm that trust model is intended.
                    task = asyncio.create_task(run_command_via_executor(parsed_json["bash_cmd"]))
                    # Hold a reference until done so the task is not
                    # garbage-collected while it is still running.
                    self._background_tasks.add(task)
                    task.add_done_callback(self._background_tasks.discard)
        except Exception as e:
            log.error(f"CLOREWS | JSON | {e}")

    async def stream_pull_logs(self):
        """Queue changed pull logs and submit queued ones within the limits.

        A log is queued when its MD5 hash differs from the last fingerprint
        recorded for that image. Queued logs stay in ``to_stream`` until the
        server confirms them with PROVEPULL. Submission is capped per run by
        ``config.max_pull_logs_per_submit_run``.

        Returns:
            Always ``True``.
        """
        if not self.authorized:
            return True

        for image_str, value in list(self.pull_logs.items()):
            if "log" not in value:
                continue
            current_hash = utils.hash_md5(value["log"])
            if self.pull_logs_last_fingerprints.get(image_str, '') != current_hash:
                self.pull_logs_last_fingerprints[image_str] = current_hash
                self.to_stream[image_str] = value["log"] # This makes sure, that each time it will submit the most recent version

        ttl_submited_characters = 0
        # Iterate a snapshot: receive() may delete confirmed entries from
        # to_stream while this coroutine is suspended inside send().
        for index, image in enumerate(list(self.to_stream)):
            try:
                limits = config.max_pull_logs_per_submit_run
                if index >= limits["instances"] or ttl_submited_characters > limits["size"]:
                    break
                submit_log = self.to_stream.get(image)
                if submit_log is None:
                    # Entry is gone from to_stream, e.g. confirmed meanwhile.
                    continue
                to_submit_log = submit_log[-config.max_pull_log_size:]
                ttl_submited_characters += len(to_submit_log)
                await self.send({
                    "pull_log": to_submit_log,
                    "image": image
                })
            except Exception as e:
                log.error(e)
        return True

    async def stream_container_logs(self):
        """Drain the log broker, update the log cache and forward new text.

        Broker items are either a list (the new set of allowed container
        names) or a ``"<container>|<data>"`` string. ``data == 'I'`` drops
        the cached log for that container; otherwise the first character of
        ``data`` is skipped and the rest is appended to the cache and sent.
        Cached logs of containers that are no longer allowed are removed.
        """
        got_data = []
        while not self.log_message_broker.empty():
            got_data.append(await self.log_message_broker.get())
        if not got_data:
            return

        for data_sample in got_data:
            if isinstance(data_sample, list):
                self.allowed_log_container_names = data_sample
                continue
            if not (isinstance(data_sample, str) and '|' in data_sample):
                continue
            container_name, data = data_sample.split('|', 1)
            if container_name not in self.allowed_log_container_names:
                continue
            if data == 'I':
                self.current_container_logs.pop(container_name, None)
                continue

            log_txt = data[1:]
            if container_name in self.current_container_logs:
                self.current_container_logs[container_name] += log_txt
                update_type = "append"
            else:
                self.current_container_logs[container_name] = log_txt
                update_type = "set"
            if len(self.current_container_logs[container_name]) > config.max_container_log_size:
                self.current_container_logs[container_name] = trim_container_log(self.current_container_logs[container_name])
            await self.send(json.dumps({"container_log":log_txt, "type":update_type, "container_name":container_name}))

        # Iterate a copy of the keys; deleting while iterating the live view
        # raises "dictionary changed size during iteration".
        for container_in_cache_name in list(self.current_container_logs):
            if container_in_cache_name not in self.allowed_log_container_names:
                del self.current_container_logs[container_in_cache_name]

    async def ensure_connection(self):
        """Connect if the client is not currently connected."""
        if not self.connected:
            await self.connect()

    async def run(self):
        """Connect, receive until disconnected, wait two seconds, and repeat forever."""
        while True:
            await self.connect()
            receive_task = asyncio.create_task(self.receive())
            await receive_task
            log.debug("CLOREWS | Waiting to reconnect WS")
            await asyncio.sleep(2)
 |