onboarding/clore_onboarding.py

515 lines
22 KiB
Python

import http.client
import datetime
import argparse
import json
import specs
import base64
import time
import math
import sys
import re
import os
import socket
import asyncio
from urllib.parse import urlparse
import subprocess
from functools import partial
class logger:
    """Minimal colorized console logger.

    Every call prints one line of the form
    ``<color>YYYY-MM-DD HH:MM:SS | LEVEL | message<reset>`` to stdout.
    """

    RED = '\033[91m'
    GREEN = '\033[92m'
    BLUE = '\033[94m'
    RESET = '\033[0m'

    @staticmethod
    def _get_current_time():
        """Return the local wall-clock time formatted as ``YYYY-MM-DD HH:MM:SS``."""
        now = datetime.datetime.now()
        return now.strftime("%Y-%m-%d %H:%M:%S")

    @staticmethod
    def _emit(color, level, message):
        """Print *message* tagged with *level*, wrapped in the ANSI *color* code."""
        stamp = logger._get_current_time()
        print(f"{color}{stamp} | {level} | {message}{logger.RESET}")

    @staticmethod
    def error(message):
        logger._emit(logger.RED, "ERROR", message)

    @staticmethod
    def success(message):
        logger._emit(logger.GREEN, "SUCCESS", message)

    @staticmethod
    def info(message):
        logger._emit(logger.BLUE, "INFO", message)
# Refuse to run unprivileged: the script starts/stops systemd services and
# writes files under /opt.
if not os.geteuid() == 0:
    logger.error("This script must be run as root!")
    sys.exit(1)
# Command-line interface. With --mock, the config paths below point into the
# local mock/ directory instead of the real system locations.
parser = argparse.ArgumentParser(description="Script with --clore-endpoint flag.")
parser.add_argument('--clore-endpoint', type=str, default="https://api.clore.ai/machine_onboarding", help='Specify the Clore API endpoint. Default is "https://api.clore.ai/machine_onboarding".')
parser.add_argument('--mock', action='store_true', help='Report predefined machine specs (testing only)')
# main() only knows "linux" and "hive". Reject anything else up front; an
# unknown mode used to leave the config unset and fail on every loop iteration.
parser.add_argument('--mode', type=str, default="linux", choices=["linux", "hive"])
parser.add_argument('--write-linux-config', type=str, default="")
parser.add_argument('--linux-hostname-override', default="")
parser.add_argument('--auth-file', type=str, default="/opt/clore-hosting/client/auth", help='Auth file location')
args = parser.parse_args()
# Hive OS config files and the locally stored onboarding config.
hive_hive_wallet_conf_path = "mock/wallet.conf" if args.mock else "/hive-config/wallet.conf"
hive_rig_conf_path = "mock/rig.conf" if args.mock else "/hive-config/rig.conf"
hive_oc_conf_path = "mock/oc.conf" if args.mock else "/hive-config/nvidia-oc.conf"
clore_conf_path = "mock/onboarding.json" if args.mock else "/opt/clore-hosting/onboarding.json"
async def run_command(command: str):
    """Run *command* via ``/bin/bash -c`` in the default executor.

    The blocking ``subprocess.run`` call happens off the event-loop thread.
    Returns ``(stdout, stderr)`` as text. The exit status is not checked.
    """
    def _invoke():
        return subprocess.run(
            ['/bin/bash', '-c', command],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
        )

    completed = await asyncio.get_running_loop().run_in_executor(None, _invoke)
    return completed.stdout, completed.stderr
def clean_config_value(value):
    """Normalize a raw ``KEY=value`` right-hand side.

    The function:
    * drops everything from the first ``#`` (inline comment);
    * strips surrounding whitespace;
    * removes one pair of matching single or double quotes around the value.
    """
    head, _sep, _comment = value.partition('#')
    cleaned = head.strip()
    is_quoted = len(cleaned) >= 2 and cleaned[0] == cleaned[-1] and cleaned[0] in ("'", '"')
    return cleaned[1:-1] if is_quoted else cleaned
def filter_name(name):
    """Keep only ASCII letters, digits, ``_`` and ``-`` from *name*."""
    return ''.join(ch for ch in name if ch.isascii() and (ch.isalnum() or ch in '_-'))
def validate_clore_config(clore_config):
    """Validate an onboarding config before it is sent to the API.

    Args:
        clore_config: the decoded config; must be a dict (JSON object).

    Returns:
        The string "Validation successful" when every check passes, otherwise
        a list of human-readable error messages (main() joins them with " | ").

    Numeric fields must be real numbers (int/float, not bool). A string such
    as "2" used to reach a ``<=`` comparison and raise TypeError instead of
    producing a validation error. The same applied to a config that was not a
    dict at all.
    """
    def is_number(value):
        # bool is a subclass of int; JSON true/false must not count as a number.
        return isinstance(value, (int, float)) and not isinstance(value, bool)

    def in_range(value, low, high):
        return is_number(value) and low <= value <= high

    def is_valid_hostname_override(value):
        # fullmatch: re.match with a trailing '$' also accepts "name\n".
        return isinstance(value, str) and 1 <= len(value) <= 64 and \
            re.fullmatch(r'[A-Za-z0-9_-]+', value) is not None

    def is_valid_auth(value):
        return isinstance(value, str) and len(value) <= 256

    def is_valid_multipliers(value):
        return isinstance(value, dict) and \
            in_range(value.get("on_demand_multiplier"), 1, 50) and \
            in_range(value.get("spot_multiplier"), 1, 50)

    def is_valid_mrl(value):
        return isinstance(value, int) and not isinstance(value, bool) and 6 <= value <= 1440

    def is_valid_keep_params(value):
        return isinstance(value, bool)

    def is_valid_pricing(config):
        # A missing key yields None from .get(), which fails is_number().
        return in_range(config.get("on_demand_bitcoin"), 0.000001, 0.005) and \
            in_range(config.get("on_demand_clore"), 0.1, 5000) and \
            in_range(config.get("spot_bitcoin"), 0.000001, 0.005) and \
            in_range(config.get("spot_clore"), 0.1, 5000)

    def is_valid_usd_pricing(autoprice):
        return in_range(autoprice.get("spot"), 0.1, 1000) and \
            in_range(autoprice.get("on_demand"), 0.1, 1000)

    if not isinstance(clore_config, dict):
        return ["config must be a JSON object"]

    errors = []
    if "hostname_override" in clore_config and not is_valid_hostname_override(clore_config["hostname_override"]):
        errors.append("hostname_override must be a string between 1-64 characters, only A-Za-z0-9_- allowed")
    if "auth" not in clore_config or not is_valid_auth(clore_config["auth"]):
        errors.append("auth is mandatory and must be a string of max 256 character")
    autoprice = clore_config.get("autoprice")
    if isinstance(autoprice, dict):
        # "usd" switches from multiplier-based pricing to fixed USD prices.
        if autoprice.get("usd"):
            if not is_valid_usd_pricing(autoprice):
                errors.append("usd pricing input is invalid")
        elif not is_valid_multipliers(autoprice):
            errors.append("multipliers are not following spec")
    if "mrl" not in clore_config or not is_valid_mrl(clore_config["mrl"]):
        errors.append("mrl is mandatory and must be an integer in range 6-1440")
    if "keep_params" in clore_config and not is_valid_keep_params(clore_config["keep_params"]):
        errors.append("keep_params must be a boolean value")
    # Crypto pricing is all-or-nothing: once any field is present, all four must be valid.
    crypto_keys = ("on_demand_bitcoin", "on_demand_clore", "spot_bitcoin", "spot_clore")
    if any(key in clore_config for key in crypto_keys) and not is_valid_pricing(clore_config):
        errors.append("All pricing fields (on_demand_bitcoin, on_demand_clore, spot_bitcoin, spot_clore) must be specified and valid.")
    return errors if errors else "Validation successful"
def base64_string_to_json(base64_string):
    """Decode base64 text (padding optional) holding UTF-8 JSON.

    Returns the parsed JSON value, or None if any step fails (bad input type,
    bad base64, bad UTF-8 or invalid JSON).
    """
    try:
        remainder = len(base64_string) % 4
        if remainder:
            # Restore the '=' padding that is often stripped from such strings.
            base64_string = base64_string + '=' * (4 - remainder)
        decoded_text = base64.b64decode(base64_string).decode('utf-8')
        return json.loads(decoded_text)
    except Exception:
        return None
def get_default_power_limits():
    """Return each NVIDIA GPU's default power limit as an int, or None.

    Parses the fifth whitespace-separated field of every "Default Power Limit"
    line in ``nvidia-smi -q -d POWER`` output. Entries that are "N/A" or not
    numeric are skipped. Returns None when nothing usable was found or the
    command fails.
    """
    cmd = "nvidia-smi -q -d POWER | grep \"Default Power Limit\" | awk '{print $5}'"
    try:
        completed = subprocess.run(cmd, shell=True, check=True, capture_output=True, text=True)
        limits = []
        for token in completed.stdout.strip().split('\n'):
            if token.lower() == 'n/a':
                continue
            try:
                limits.append(int(float(token)))
            except ValueError:
                pass
        return limits or None
    except Exception:
        return None
def validate_and_convert(input_str, min_val, max_val, adjust_bounds=False):
    """Parse whitespace-separated integers from *input_str*.

    With *adjust_bounds*, each value is clamped into [min_val, max_val].
    Without it, the whole result is None if any value is out of range.
    Returns None whenever parsing fails.
    """
    try:
        values = list(map(int, input_str.split()))
        if adjust_bounds:
            return [max(min(v, max_val), min_val) for v in values]
        in_range = all(min_val <= v <= max_val for v in values)
        return values if in_range else None
    except Exception:
        return None
def get_number_or_last(numbers, index):
    """Return ``numbers[index]``, reusing the last element when *index* runs past the end.

    *numbers* must be non-empty.
    """
    return numbers[index] if index < len(numbers) else numbers[-1]
async def async_read_file(path):
    """Read a text file without blocking the event loop.

    The original coroutine called ``open().read()`` directly, which blocked
    the event-loop thread. The read now runs in the default executor.

    Returns the file contents, or None if the file cannot be opened or read.
    This best-effort behavior is intentional: callers treat None as "absent".
    """
    def _read():
        with open(path, 'r') as file:
            return file.read()

    try:
        return await asyncio.get_running_loop().run_in_executor(None, _read)
    except Exception:
        return None
def extract_last_setcore_setmem(input_string):
    """Return the integers following the last ``--setcore`` and ``--setmem`` flags.

    Returns:
        A tuple ``(setcore_values, setmem_values)``. Each element is a list of
        ints, empty when that flag does not occur. Unscannable input (e.g.
        None) yields ``([], [])``. The error path previously returned a list
        ``[[], []]`` while the normal path returned a tuple.

    Note: ``\\s`` also matches newlines, so a value list may continue onto a
    following line that starts with digits.
    """
    try:
        setcore_matches = re.findall(r'--setcore\s+((?:\d+\s*)+)(?=\D|$)', input_string)
        setmem_matches = re.findall(r'--setmem\s+((?:\d+\s*)+)(?=\D|$)', input_string)
        last_setcore = [int(num) for num in setcore_matches[-1].split()] if setcore_matches else []
        last_setmem = [int(num) for num in setmem_matches[-1].split()] if setmem_matches else []
        return last_setcore, last_setmem
    except Exception:
        return [], []
async def hive_parse_config(file_content):
    """Parse Hive-style ``KEY=value`` text into a dict.

    Lines are stripped first. Blank lines, lines starting with ``#`` and lines
    without ``=`` are skipped. The key is everything before the first ``=``;
    the value goes through clean_config_value. A later duplicate overwrites an
    earlier one. Falsy *file_content* (None or "") yields an empty dict.
    """
    parsed = {}
    if not file_content:
        return parsed
    for raw_line in file_content.split('\n'):
        entry = raw_line.strip()
        if entry.startswith('#') or '=' not in entry:
            continue
        name, _sep, raw_value = entry.partition('=')
        parsed[name] = clean_config_value(raw_value)
    return parsed
async def hive_load_configs(default_power_limits, static_config):
    """Build the onboarding config for a Hive OS rig.

    Reads wallet.conf, rig.conf and nvidia-oc.conf. Their paths are chosen at
    module level; under --mock they point into mock/. From these it derives:

    * machine_name: the first 32 characters of rig.conf WORKER_NAME passed
      through filter_name, then "_HIVE_" and RIG_ID.
    * clore_config: the JSON decoded from the last wallet.conf ``*_TEMPLATE``
      value that base64-decodes to truthy JSON. It is replaced by the parsed
      *static_config* (when that is truthy JSON) if wallet.conf lacks either
      such a template or ``CUSTOM_MINER=clore``.
    * out_oc_config: per-GPU stock overclock settings. Built only when
      clore_config contains "set_stock_oc", default_power_limits is non-empty
      and nvidia-oc.conf has content; otherwise {}.

    Args:
        default_power_limits: per-GPU default power limits. out_oc_config gets
            one entry per element.
        static_config: JSON text of the locally stored config, or None.

    Returns:
        ``(machine_name, clore_config, out_oc_config)``, or None after logging
        if an exception escapes the outer try block.

    Side effects:
        Calls os._exit(1) if rig.conf lacks WORKER_NAME or RIG_ID. If the
        CLORE miner is found neither in wallet.conf nor via static_config, it
        disables and stops the clore-hosting, docker and clore-onboarding
        services, then calls sys.exit(0).
    """
    # Best effort: a missing (None) or malformed static config leaves this None.
    parsed_static_config = None
    try:
        parsed_static_config = json.loads(static_config)
    except Exception:
        pass
    try:
        # Read the Hive config files; async_read_file returns None for a file it cannot read.
        wallet_conf_content = await async_read_file(hive_hive_wallet_conf_path)
        rig_conf_content = await async_read_file(hive_rig_conf_path)
        # Parse rig config; WORKER_NAME and RIG_ID are needed for the machine name.
        rig_conf = await hive_parse_config(rig_conf_content)
        if not rig_conf or "WORKER_NAME" not in rig_conf or "RIG_ID" not in rig_conf:
            print("WORKER_NAME or RIG_ID is missing from rig config")
            # Hard exit: terminates the process immediately, without raising SystemExit.
            os._exit(1)
        clore_miner_present = False
        # Parse wallet config
        clore_config = None
        get_oc_config = False
        fs_mem_lock = []
        fs_core_lock = []
        if wallet_conf_content:
            # The last --setcore/--setmem values in wallet.conf override
            # LCLOCK/LMEM from nvidia-oc.conf further down.
            # NOTE(review): the regex scans the raw text, so commented-out lines count too.
            fs_core_lock, fs_mem_lock = extract_last_setcore_setmem(wallet_conf_content)
            for wallet_conf_line in wallet_conf_content.split('\n'):
                wallet_conf_line = wallet_conf_line.strip()
                if wallet_conf_line[:1] != "#" and '=' in wallet_conf_line:
                    key, value = [wallet_conf_line.split('=', 1)[0], clean_config_value(wallet_conf_line.split('=', 1)[1])]
                    # Any *_TEMPLATE value that base64-decodes to truthy JSON becomes
                    # the clore config; a later one overwrites an earlier one.
                    if key[-9:] == "_TEMPLATE":
                        possible_clore_config = base64_string_to_json(value)
                        if possible_clore_config:
                            clore_config = possible_clore_config
                    elif key == "CUSTOM_MINER" and value == "clore":
                        clore_miner_present = True
        # Fall back to the locally stored config when wallet.conf is incomplete.
        if (not clore_miner_present or not clore_config) and parsed_static_config:
            clore_miner_present = True
            clore_config = parsed_static_config
        try:
            if clore_config and "set_stock_oc" in clore_config:
                get_oc_config=True
        except Exception as es:
            # `in` raises TypeError when the decoded JSON is not a container (e.g. a number).
            pass
        if not clore_miner_present:
            logger.info("CLORE not found in flighsheet, exiting")
            await run_command("systemctl disable clore-hosting.service ; systemctl stop clore-hosting.service ; systemctl disable docker ; systemctl stop docker ; systemctl disable clore-onboarding.service ; systemctl stop clore-onboarding.service")
            sys.exit(0)
        out_oc_config = {}
        if get_oc_config and default_power_limits:
            nvidia_oc = await async_read_file(hive_oc_conf_path)
            gpu_cnt = len(default_power_limits)
            if nvidia_oc and gpu_cnt > 0:
                try:
                    # Each setting stays None unless nvidia-oc.conf gives a non-empty value.
                    core_offset = None
                    mem_offset = None
                    core_lock = None
                    mem_lock = None
                    pl_static = None
                    for nvidia_conf_line in nvidia_oc.split('\n'):
                        nvidia_conf_line = nvidia_conf_line.strip()
                        if nvidia_conf_line[:1] != "#" and '=' in nvidia_conf_line:
                            key, value = [nvidia_conf_line.split('=', 1)[0], clean_config_value(nvidia_conf_line.split('=', 1)[1])]
                            # Values are space-separated per-GPU lists, clamped into the given ranges.
                            if value == "":
                                pass
                            elif key=="CLOCK":
                                # Clamped to [-2000, 2000], then halved (floor).
                                # NOTE(review): presumably a unit conversion for the API — confirm.
                                core_offset = [math.floor(num / 2) for num in validate_and_convert(value, -2000, 2000, adjust_bounds=True)]
                            elif key=="MEM":
                                # Clamped to [-2000, 6000], then halved (floor); see CLOCK note.
                                mem_offset = [math.floor(num / 2) for num in validate_and_convert(value, -2000, 6000, adjust_bounds=True)]
                            elif key=="PLIMIT":
                                pl_static = validate_and_convert(value, 1, 1500, adjust_bounds=True)
                            elif key=="LCLOCK":
                                core_lock = validate_and_convert(value, 0, 12000, adjust_bounds=True)
                            elif key=="LMEM":
                                mem_lock = validate_and_convert(value, 0, 32000, adjust_bounds=True)
                    # Locks from wallet.conf --setmem/--setcore take precedence over LMEM/LCLOCK.
                    if len(fs_mem_lock)>0:
                        mem_lock = fs_mem_lock
                    if len(fs_core_lock)>0:
                        core_lock = fs_core_lock
                    #print(mem_lock, core_lock)
                    if core_offset or mem_offset or pl_static or mem_lock or core_lock:
                        # One entry per GPU, keyed "0", "1", ... A list shorter than the GPU
                        # count repeats its last value. A missing offset defaults to 0 and a
                        # missing power limit to that GPU's default limit.
                        for gpu_idx, default_pl in enumerate(default_power_limits):
                            out_oc_config[str(gpu_idx)] = {
                                "core": get_number_or_last(core_offset, gpu_idx) if core_offset else 0,
                                "mem": get_number_or_last(mem_offset, gpu_idx) if mem_offset else 0,
                                "pl": get_number_or_last(pl_static, gpu_idx) if pl_static else default_pl
                            }
                            if type(core_lock)==list and len(core_lock)>0:
                                out_oc_config[str(gpu_idx)]["core_lock"] = get_number_or_last(core_lock, gpu_idx)
                            if type(mem_lock)==list and len(mem_lock)>0:
                                out_oc_config[str(gpu_idx)]["mem_lock"] = get_number_or_last(mem_lock, gpu_idx)
                except Exception as oc_info_e:
                    # NOTE(review): a non-integer CLOCK or MEM token makes validate_and_convert
                    # return None, so the comprehension raises TypeError. The exception lands
                    # here before any entry is written, so out_oc_config stays {}, and main()
                    # then sends clear_oc_override instead of the user's stock OC.
                    pass
        # Construct machine name
        machine_name = f"{filter_name(rig_conf['WORKER_NAME'][:32])}_HIVE_{rig_conf['RIG_ID']}"
        return machine_name, clore_config, out_oc_config
    except Exception as e:
        # Logs and falls through, so the coroutine returns None.
        logger.error(f"Can't load rig.conf, wallet.conf | {e}")
async def post_request(url, body, headers=None, timeout=15):
    """POST *body* as JSON to *url*.

    The blocking http.client exchange runs in the default executor, so the
    event loop is not stalled for up to *timeout* seconds.

    Args:
        url: an http:// or https:// URL.
        body: a JSON-serializable object.
        headers: extra request headers. The dict is copied, not mutated.
        timeout: socket timeout in seconds.

    Returns:
        ``(status_code, response_data)``. response_data is the decoded JSON
        body when it parses, otherwise the raw text. Returns ``(None, None)``
        on HTTP protocol errors and on any socket-level failure (timeout,
        refused connection, DNS or SSL error); previously only
        HTTPException/TimeoutError were caught.

    Raises:
        ValueError: for URL schemes other than http/https.
    """
    parsed_url = urlparse(url)
    if parsed_url.scheme == 'https':
        conn = http.client.HTTPSConnection(parsed_url.hostname, parsed_url.port or 443, timeout=timeout)
    elif parsed_url.scheme == 'http':
        conn = http.client.HTTPConnection(parsed_url.hostname, parsed_url.port or 80, timeout=timeout)
    else:
        raise ValueError(f"Unsupported URL scheme: {parsed_url.scheme}")
    json_data = json.dumps(body)
    # Copy so the caller's dict does not gain a Content-Type entry.
    request_headers = dict(headers or {})
    request_headers['Content-Type'] = 'application/json'
    path = parsed_url.path
    if parsed_url.query:
        path += '?' + parsed_url.query

    def _exchange():
        try:
            conn.request("POST", path, body=json_data, headers=request_headers)
            response = conn.getresponse()
            response_data = response.read().decode('utf-8')
            try:
                response_data = json.loads(response_data)
            except ValueError:
                pass  # Not JSON: return the raw text.
            return response.status, response_data
        except (http.client.HTTPException, OSError) as e:
            # OSError covers TimeoutError/socket.timeout, refused connections,
            # DNS failures and ssl.SSLError.
            print(f"Request failed: {e}")
            return None, None
        finally:
            conn.close()

    return await asyncio.get_running_loop().run_in_executor(None, _exchange)
def clean_clore_config(clore_config):
    """Return a shallow copy of *clore_config* without the keys main() handles itself.

    "auth" is sent as a request header instead; "hostname_override",
    "set_stock_oc" and "save_config" are only used locally. The input dict is
    left untouched.
    """
    stripped = clore_config.copy()
    for local_key in ("auth", "hostname_override", "set_stock_oc", "save_config"):
        stripped.pop(local_key, None)
    return stripped
def verify_or_update_file(file_path: str, expected_content: str) -> bool:
    """Ensure *file_path* contains exactly *expected_content*.

    Returns:
        True if the file already held that content. False if it was missing
        or different and has just been (re)written. Any I/O error also yields
        True, which callers read as "nothing changed".
    """
    try:
        if os.path.exists(file_path):
            with open(file_path, 'r', encoding='utf-8') as handle:
                if handle.read() == expected_content:
                    return True
        with open(file_path, 'w', encoding='utf-8') as handle:
            handle.write(expected_content)
    except Exception:
        return True
    return False
def get_machine_id():
    """Return the stripped contents of /etc/machine-id, or None if it is not a regular file."""
    path = "/etc/machine-id"
    if not os.path.isfile(path):
        return None
    with open(path, "r") as fh:
        return fh.read().strip()
# Unix timestamp after which a config rejected with "exceeded_limit" is resent;
# 0 means no such retry is pending (read and written by main()).
next_retry_reached_server_limit = 0
# One-shot mode: decode a base64 JSON config, optionally set hostname_override,
# write it to the onboarding config path and exit.
if args.write_linux_config:
    linux_config = base64_string_to_json(args.write_linux_config)
    # Only a non-empty JSON object is a usable config. A list would crash on the
    # key assignment below, and a string or number would be written as-is.
    if isinstance(linux_config, dict) and linux_config:
        if args.linux_hostname_override:
            # fullmatch: re.match with '$' would also accept a trailing newline.
            if 1 <= len(args.linux_hostname_override) <= 64 and re.fullmatch(r'[A-Za-z0-9_-]+', args.linux_hostname_override):
                linux_config["hostname_override"] = args.linux_hostname_override
            else:
                logger.error("Input hostname not valid")
                sys.exit(1)
        verify_or_update_file(clore_conf_path, json.dumps(linux_config))
        logger.success("Config written")
        sys.exit(0)
    else:
        logger.error("Invalid config")
        sys.exit(1)
async def main(machine_specs):
    """Keep this machine registered with the Clore onboarding API.

    Each iteration (about every 5 s) does the following:
    * Loads the config. In "linux" mode it is JSON from clore_conf_path; in
      "hive" mode it is derived by hive_load_configs.
    * Validates it.
    * POSTs it with *machine_specs* to args.clore_endpoint, but only when it
      differs from the last config sent or an "exceeded_limit" back-off has
      expired.
    * If the API returns communication tokens, writes them to args.auth_file
      and starts clore-hosting.service (token unchanged) or restarts it (new
      token).

    Fixes over the previous version:
    * The token branch tested the truthy literal "private_communication_token"
      instead of its presence.
    * A failed request (no/non-JSON response, or an exception) left the config
      marked as sent, so it was never retried.
    * A None result from hive_load_configs and a non-dict config crashed the
      iteration.
    * An unknown mode looped forever.

    Exits the process if the machine ID or the NVIDIA default power limits
    cannot be read, or if args.mode is unsupported.
    """
    global next_retry_reached_server_limit
    last_used_config = None
    ever_pending_creation = False
    machine_id = get_machine_id()
    default_power_limits = get_default_power_limits()
    if not machine_id:
        logger.error("Can't load machine ID")
        sys.exit(1)
    if not default_power_limits:
        logger.error("Can't load default power limits of nVidia GPU(s)")
        sys.exit(1)
    oc_config = {}
    while True:
        try:
            if args.mode == "linux":
                raw_config = await async_read_file(clore_conf_path)
                if raw_config is None:
                    raise RuntimeError(f"Can't read config file {clore_conf_path}")
                clore_config = json.loads(raw_config)
                machine_name = f"LINUX_{machine_id}"
            elif args.mode == "hive":
                static_clore_config = await async_read_file(clore_conf_path)
                hive_result = await hive_load_configs(default_power_limits, static_clore_config)
                if hive_result is None:
                    # hive_load_configs has already logged the cause.
                    raise RuntimeError("Hive configuration could not be loaded")
                machine_name, clore_config, oc_config = hive_result
            else:
                logger.error(f"Unsupported mode: {args.mode}")
                sys.exit(1)
            if isinstance(clore_config, dict):
                config_validation = validate_clore_config(clore_config)
            else:
                config_validation = ["config must be a JSON object"]
            if config_validation == "Validation successful":
                if "save_config" in clore_config and args.mode == "hive":
                    verify_or_update_file(clore_conf_path, json.dumps(clore_config))
                if "set_stock_oc" in clore_config:
                    if oc_config == {}:
                        clore_config["clear_oc_override"] = True
                    else:
                        clore_config["stock_oc_override"] = oc_config
                limit_retry_due = (next_retry_reached_server_limit > 0
                                   and time.time() > next_retry_reached_server_limit)
                if clore_config != last_used_config or limit_retry_due:
                    # Snapshot taken before "name"/"specs" are added, so the next
                    # iteration's freshly loaded config compares equal to it.
                    sent_config = clore_config.copy()
                    if "hostname_override" in clore_config:
                        machine_name = clore_config["hostname_override"]
                    clore_config["name"] = machine_name
                    clore_config["specs"] = machine_specs
                    status_code, response_data = await post_request(
                        args.clore_endpoint,
                        clean_clore_config(clore_config),
                        {"auth": clore_config["auth"]},
                        15
                    )
                    # Recorded only once the request returned. If post_request raised,
                    # the config is still "unsent" and the next iteration retries it.
                    last_used_config = sent_config
                    next_retry_reached_server_limit = 0
                    if not isinstance(response_data, dict):
                        # Transport failure (None) or a non-JSON body: retry later.
                        logger.error(f"No valid API response (status {status_code}), retrying in 65s")
                        await asyncio.sleep(60)
                        last_used_config = None
                    elif response_data.get("status") == "invalid_auth":
                        logger.error("Invalid auth token")
                    elif response_data.get("error") == "exceeded_rate_limit":
                        logger.error("Exceeded API limits, probably a lot of servers on same network")
                        # 60 s here plus the 5 s loop sleep below.
                        logger.info("Retrying request in 65s")
                        await asyncio.sleep(60)
                        last_used_config = None
                    elif response_data.get("status") == "exceeded_limit":
                        logger.error("Your account already has the maximal server limit, retrying in 12hr")
                        next_retry_reached_server_limit = time.time() + 60*60*12
                    elif response_data.get("status") == "creation_pending":
                        logger.info("Machine creation is pending on clore.ai side")
                        await asyncio.sleep(60 if ever_pending_creation else 10)
                        ever_pending_creation = True
                        last_used_config = None
                    elif "init_communication_token" in response_data and "private_communication_token" in response_data:
                        clore_hosting_sw_auth_str = f"{response_data['init_communication_token']}:{response_data['private_communication_token']}"
                        was_ok = verify_or_update_file(args.auth_file, clore_hosting_sw_auth_str)
                        if was_ok:
                            logger.info("Token for hosting software already configured")
                            await run_command("systemctl start clore-hosting.service")
                        else:
                            logger.success("Updated local auth file, restarting clore hosting")
                            await run_command("systemctl restart clore-hosting.service")
                    else:
                        logger.error("Unknown API response, retrying in 65s")
                        await asyncio.sleep(60)
                        last_used_config = None
            else:
                logger.error(f"Could not parse config - {' | '.join(config_validation)}")
        except Exception as e:
            print(e)
        await asyncio.sleep(5)
if __name__ == "__main__":
    # Collect machine specs once through the project's `specs` module
    # (benchmark_disk=True), then hand them to main()'s endless onboarding loop.
    machine_specs = specs.get(mock=args.mock, benchmark_disk=True)
    asyncio.run(main(machine_specs=machine_specs))