230 lines
8.3 KiB
Python
230 lines
8.3 KiB
Python
from typing import Optional, Tuple, Dict
|
|
from lib import config as config_module
|
|
from lib import logging as logging_lib
|
|
from lib import nvml
|
|
import subprocess
|
|
import hashlib
|
|
import asyncio
|
|
import random
|
|
import string
|
|
import shutil
|
|
import shlex
|
|
import time
|
|
import math
|
|
import json
|
|
import os
|
|
|
|
# Module-level shortcut to the project logger from lib.logging; the helpers
# below call it as log.debug(...) and log.error(...).
log = logging_lib.log
|
|
|
|
# Module-level shortcut to the project configuration object; this module reads
# config.auth_file and config.extra_allowed_images_file from it.
config = config_module.config
|
|
|
|
def run_command(command):
    """Run *command* through the shell and capture what it prints.

    The string is handed to the shell verbatim (``shell=True``), so it may use
    pipes, ``&&`` and redirections -- and must never contain untrusted input.

    Returns:
        A ``(returncode, stdout, stderr)`` tuple in which both output streams
        are decoded as text and stripped of surrounding whitespace.
    """
    completed = subprocess.run(
        command,
        shell=True,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        text=True,
    )
    out_text = completed.stdout.strip()
    err_text = completed.stderr.strip()
    return completed.returncode, out_text, err_text
|
|
|
|
def parse_rule_to_dict(rule):
    """Parse an option-style rule string into a ``{option: value}`` dict.

    The rule is tokenized with shell quoting rules (``shlex.split``), so quoted
    values such as ``--comment 'a b'`` stay intact. Presumably this is used for
    iptables-style rules -- confirm against callers.

    * A token starting with ``-`` is an option. If the next token exists and
      does not start with ``-``, it is taken as that option's value.
      Otherwise the option maps to ``True``.
    * Tokens that are neither options nor option values (e.g. ``!``) are
      skipped, so negation markers are not preserved.
    * A repeated option keeps only its last value.
    * A value that itself starts with ``-`` (e.g. a negative number) is
      treated as the next option.

    Returns:
        A new dict mapping each option token to its value string or ``True``.
    """
    tokens = shlex.split(rule)
    rule_dict = {}
    i = 0
    while i < len(tokens):
        token = tokens[i]
        if not token.startswith("-"):
            i += 1
            continue

        has_value = i + 1 < len(tokens) and not tokens[i + 1].startswith("-")
        if has_value:
            rule_dict[token] = tokens[i + 1]
            i += 2  # consumed the option and its value
        else:
            # Valueless flag: advance by one only, so the following option is
            # not skipped (previously it advanced by two and lost it).
            rule_dict[token] = True
            i += 1
    return rule_dict
|
|
|
|
def normalize_rule(rule_dict):
    """Return a copy of *rule_dict* whose keys are in sorted order.

    Dict equality already ignores insertion order, so this only fixes the
    iteration / serialization order. It is the place to add value
    canonicalization (e.g. normalizing IP address formats) if that is ever
    needed.
    """
    ordered_keys = sorted(rule_dict)
    return {key: rule_dict[key] for key in ordered_keys}
|
|
|
|
def get_auth():
    """Return the API auth token, or ``''`` when none can be obtained.

    The ``AUTH_TOKEN`` environment variable wins whenever it is set, even to an
    empty string. Otherwise the token is read from ``config.auth_file`` as
    UTF-8 with surrounding whitespace stripped. Any failure along the way
    (missing file, decode error, missing config) yields ``''``.
    """
    try:
        env_token = os.environ.get('AUTH_TOKEN')
        if env_token is not None:
            return env_token
        with open(config.auth_file, "r", encoding="utf-8") as token_file:
            return token_file.read().strip()
    except Exception:
        return ''
|
|
|
|
def get_allowed_container_names():
    """Return the names listed in ``$ALLOWED_CONTAINER_NAMES``.

    The variable holds a comma-separated list. Empty items (e.g. from ``a,,b``
    or a trailing comma) are dropped. Whitespace is NOT trimmed. An unset or
    empty variable gives ``[]``.
    """
    raw_value = os.getenv("ALLOWED_CONTAINER_NAMES")
    if not raw_value:
        return []
    return list(filter(None, raw_value.split(',')))
|
|
|
|
def unix_timestamp():
    """Return the current Unix time in whole seconds (fraction truncated)."""
    seconds_since_epoch = time.time()
    return int(seconds_since_epoch)
|
|
|
|
def hash_md5(input_string):
    """Return the lowercase hex MD5 digest of *input_string* encoded as UTF-8.

    MD5 is not collision-resistant. Presumably this is only used for
    fingerprints / identifiers -- do not use it for anything security-relevant.
    """
    digest = hashlib.md5()
    digest.update(input_string.encode("utf-8"))
    return digest.hexdigest()
|
|
|
|
def run_command_v2(command, timeout=900):
    """Run *command* via ``bash -c`` with output going to this process's stdio.

    Args:
        command: Command line interpreted by bash.
        timeout: Seconds to wait before the command is killed (default 900,
            i.e. 15 minutes). ``None`` waits indefinitely.

    Returns:
        ``True`` if the command exited with status 0. ``False`` if it exited
        non-zero or timed out. As before, those failures are only logged at
        debug level, never raised; the return value lets callers react to them.
    """
    try:
        # subprocess.run kills and reaps the child itself when the timeout expires.
        subprocess.run(["bash", "-c", command], check=True, timeout=timeout)
    except subprocess.CalledProcessError as e:
        log.debug(f"run_command_v2() | A subprocess error occurred: {e}")
        return False
    except subprocess.TimeoutExpired as e:
        log.debug(f"run_command_v2() | Command timed out: {e}")
        return False
    return True
|
|
|
|
def yes_no_question(prompt):
    """Ask *prompt* on the terminal until the user answers yes or no.

    Accepts ``y``/``yes``/``n``/``no`` case-insensitively, ignoring surrounding
    whitespace. Any other reply prints a hint and asks again.

    Returns:
        ``True`` for yes, ``False`` for no.
    """
    verdicts = {'y': True, 'yes': True, 'n': False, 'no': False}
    question = prompt + " (y/n): "
    while True:
        reply = input(question).strip().lower()
        if reply in verdicts:
            return verdicts[reply]
        print("Please enter 'y' or 'n'.")
|
|
|
|
def validate_cuda_version(ver_str, min_major=11, min_minor=7):
    """Return True if *ver_str* denotes a CUDA version >= ``min_major.min_minor``.

    *ver_str* is expected in a colon-separated ``"<major>:<minor>"`` form
    (TODO confirm why callers use ':' rather than '.'). Only the first two
    fields are considered.

    A string without ':' returns ``False``, and so does a non-integer
    component that the comparison needs. These used to raise ValueError.
    A major version above *min_major* is accepted without looking at the
    minor field, as before.

    Args:
        ver_str: Version string to check.
        min_major: Minimum major version (default 11).
        min_minor: Minimum minor version when the major equals *min_major*
            (default 7).
    """
    if ':' not in ver_str:
        return False
    parts = ver_str.split(':')

    try:
        major = int(parts[0])
    except ValueError:
        return False

    if major != min_major:
        return major > min_major

    try:
        return int(parts[1]) >= min_minor
    except ValueError:
        return False
|
|
def generate_random_string(length):
    """Return a random string of *length* ASCII letters and digits.

    Characters are drawn with :func:`secrets.choice`, which uses the OS
    CSPRNG. The ``random`` module's Mersenne Twister output is predictable, so
    it is unsuitable if the string ever serves as a password, token or
    unguessable name. A non-positive *length* yields ``''``.
    """
    import secrets  # local import keeps the module-level import block unchanged

    alphabet = string.ascii_letters + string.digits
    return ''.join(secrets.choice(alphabet) for _ in range(length))
|
|
|
|
# Search path exported when shelling out to HiveOS tooling in
# hive_set_miner_status(): /hive/bin and /hive/sbin are searched first, then the
# standard system directories.
# NOTE(review): the trailing "./" adds the current working directory to PATH, so a
# command not found elsewhere can resolve to a file in the cwd -- confirm it is needed.
HIVE_PATH="/hive/bin:/hive/sbin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:./"
|
|
|
|
def _miner_screen_pids(screen_listing, session_name="miner"):
    """Return the PIDs (as digit strings) of screen sessions named *session_name*.

    *screen_listing* is the text printed by ``screen -ls``. Session lines are
    tab-indented and their first field is ``<pid>.<name>``. Some screen
    versions add a date column before the state and some do not, so two or
    more fields are accepted (the old ``> 2`` test missed the shorter form).
    """
    pids = []
    for line in screen_listing.split('\n'):
        # Drop the leading tab so the first field is "<pid>.<name>".
        fields = line.replace('\t', '', 1).split('\t')
        if len(fields) < 2 or '.' not in fields[0]:
            continue
        pid, name = fields[0].split('.', 1)
        # Only accept numeric PIDs: the value is interpolated into a shell command.
        if name == session_name and pid.isdigit():
            pids.append(pid)
    return pids


def hive_set_miner_status(enabled=False):
    """Start or stop the HiveOS miner so that it matches *enabled*.

    Whether the miner is running is judged by the presence of a screen session
    named ``miner``.

    * More than one such session: all of them are killed (``kill -9``) and
      ``screen -wipe`` runs after the last kill. The miner is not restarted in
      this call -- presumably a later call does that.
    * Running and ``enabled`` is false: ``/hive/bin/miner stop`` is run.
    * Not running and ``enabled`` is true: ``nvidia-oc`` is applied, then
      ``/hive/bin/miner start`` is run.

    All commands go through ``run_command`` (shell) with sudo where shown.
    """
    returncode, listing, _ = run_command("screen -ls")
    # Both 0 and 1 are accepted as "listing is usable"; screen is believed to
    # exit 1 on some versions even when sessions exist -- TODO confirm.
    pids = _miner_screen_pids(listing) if returncode in (0, 1) else []
    miner_screen_running = bool(pids)

    if len(pids) > 1:
        # Duplicate miner sessions should never happen: destroy all of them.
        for idx, pid in enumerate(pids):
            wipe = ' && screen -wipe' if idx == len(pids) - 1 else ''
            run_command(f"kill -9 {pid}{wipe}")
    elif miner_screen_running and not enabled:
        run_command(f"/bin/bash -c \"PATH={HIVE_PATH} && sudo /hive/bin/miner stop\"")
    elif enabled and not miner_screen_running:
        run_command(f"/bin/bash -c \"export PATH={HIVE_PATH} && sudo /hive/sbin/nvidia-oc && source ~/.bashrc ; sudo /hive/bin/miner start\"")
|
|
|
|
def _is_valid_allowed_image_entry(item):
    """Return True if *item* is exactly ``{'repository': str, 'allowed_tags': [str, ...]}``."""
    return (
        isinstance(item, dict)
        and set(item.keys()) == {'repository', 'allowed_tags'}
        and isinstance(item['repository'], str)
        and isinstance(item['allowed_tags'], list)
        and all(isinstance(tag, str) for tag in item['allowed_tags'])
    )


def get_extra_allowed_images(path=None):
    """Load the optional JSON list of extra allowed container images.

    Args:
        path: File to read. Defaults to ``config.extra_allowed_images_file``.

    Returns:
        The parsed list when every entry is a dict with exactly the keys
        ``repository`` (str) and ``allowed_tags`` (list of str). Otherwise
        ``[]``. A missing file yields ``[]`` without logging. Read or JSON
        errors and entries with an unexpected shape are logged via
        ``log.error`` and also yield ``[]``.
    """
    if path is None:
        path = config.extra_allowed_images_file

    try:
        # Explicit UTF-8 so parsing does not depend on the host locale.
        with open(path, 'r', encoding='utf-8') as file:
            data = json.load(file)
    except FileNotFoundError:
        # The file is optional; its absence just means "no extra images".
        return []
    except Exception as e:
        log.error(f"get_extra_allowed_images() | error: {e}")
        return []

    if isinstance(data, list) and all(_is_valid_allowed_image_entry(item) for item in data):
        return data

    log.error(f"get_extra_allowed_images() | error: unexpected format in {path}")
    return []
|
|
|
|
def _decode_output(data: Optional[bytes]) -> str:
    """Decode captured bytes as UTF-8 (bad bytes become U+FFFD) and strip whitespace."""
    return data.decode('utf-8', errors='replace').strip() if data else ''


async def _stop_process(proc: "asyncio.subprocess.Process", grace: float = 5) -> None:
    """Terminate *proc*, escalating to kill if it has not exited within *grace* seconds."""
    try:
        proc.terminate()
    except ProcessLookupError:
        # The process exited between the timeout firing and terminate().
        pass
    try:
        await asyncio.wait_for(proc.wait(), timeout=grace)
    except asyncio.TimeoutError:
        try:
            proc.kill()
        except ProcessLookupError:
            pass
        await proc.wait()


async def async_run_command(
    command: str,
    timeout: Optional[float] = None,
    env: Optional[Dict[str, str]] = None
) -> Tuple[int, str, str]:
    """Run *command* through the shell asynchronously and capture its output.

    Args:
        command: Shell command line (never pass untrusted input).
        timeout: Seconds to wait for completion. ``None`` waits indefinitely.
            On expiry the process is terminated, then killed if it has not
            exited within 5 seconds.
        env: Environment for the child process. See the note below for ``None``.

    Returns:
        ``(returncode, stdout, stderr)``, with the output decoded as UTF-8
        (undecodable bytes replaced rather than failing the call) and stripped.
        On timeout: ``(-1, '', 'Command timed out after <timeout> seconds')``.
        If the process could not be started or another error occurs:
        ``(-1, '', str(error))``.
    """
    # NOTE(review): env=None runs the command with an EMPTY environment rather
    # than inheriting os.environ (the usual subprocess convention). Kept as-is
    # because callers may rely on it -- confirm whether this is intended.
    command_env = env if env is not None else {}

    try:
        proc = await asyncio.create_subprocess_shell(
            command,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            env=command_env,
        )

        try:
            stdout, stderr = await asyncio.wait_for(proc.communicate(), timeout=timeout)
        except asyncio.TimeoutError:
            await _stop_process(proc)
            return -1, '', f'Command timed out after {timeout} seconds'

        return proc.returncode, _decode_output(stdout), _decode_output(stderr)
    except Exception as e:
        return -1, '', str(e)
|
|
|
|
def get_free_space_mb(path):
    """Return the ``free`` figure from ``shutil.disk_usage(path)`` in whole MiB.

    The value is rounded down. *path* may be any file or directory on the
    filesystem of interest.
    """
    usage = shutil.disk_usage(path)
    bytes_per_mb = 1024 * 1024
    return usage.free // bytes_per_mb
|
|
|
|
def get_directory_size_mb(path):
    """Return the total size of regular files under *path*, in whole MiB (rounded down).

    Symbolic links to files are skipped, so a link's target is never counted
    through the link. Files that vanish or cannot be stat'ed while the tree is
    being walked are skipped instead of raising. A path that does not exist
    gives 0.
    """
    total_size = 0
    for dirpath, _dirnames, filenames in os.walk(path):
        for name in filenames:
            file_path = os.path.join(dirpath, name)
            if os.path.islink(file_path):
                continue
            try:
                total_size += os.path.getsize(file_path)
            except OSError:
                # Deleted or made unreadable after os.walk listed it; checking
                # os.path.exists() first would still race, so ask forgiveness.
                continue
    return total_size // (1024 * 1024)
|
|
|
|
class shm_calculator:
    """Size a container's shared memory (shm) from the GPU VRAM assigned to it.

    ``total_ram`` and the per-GPU VRAM values returned by
    ``nvml.get_vram_per_gpu()`` are assumed to be in MB, like the constants
    used below -- TODO confirm units with callers.
    """

    def __init__(self, total_ram):
        self.total_ram = total_ram
        # Filled lazily on the first calculate() call; indexed by GPU position.
        self.gpu_vram_sizes = []

    def calculate(self, used_gpu_ids):
        """Return the shm size (an int, MB) for a container using *used_gpu_ids*.

        *used_gpu_ids* is ``'*'`` for every GPU, or a container of GPU indices.
        If 1.5x the selected VRAM fits in the RAM left after a 2500 MB reserve,
        that amount is used. Otherwise the selected GPUs get their proportional
        share of that spare RAM. The result is never below 64 MB, and is 64 MB
        when no VRAM is selected or known.
        """
        host_reserve_mb = 2500
        minimum_mb = 64

        if len(self.gpu_vram_sizes) == 0:
            self.gpu_vram_sizes = nvml.get_vram_per_gpu()

        vram = self.gpu_vram_sizes
        vram_all = sum(vram)
        vram_selected = sum(
            size
            for gpu_idx, size in enumerate(vram)
            if used_gpu_ids == '*' or gpu_idx in used_gpu_ids
        )

        if vram_selected == 0 or vram_all == 0:
            return minimum_mb

        spare_ram = self.total_ram - host_reserve_mb
        candidate = vram_selected * 1.5
        if candidate < spare_ram:
            shm_size = candidate
        else:
            shm_size = vram_selected / vram_all * spare_ram

        # max() with the minimum first keeps the minimum on ties, as before.
        return math.floor(max(minimum_mb, shm_size))