diff --git a/lib/config.py b/lib/config.py index 101c932..2fae529 100644 --- a/lib/config.py +++ b/lib/config.py @@ -32,7 +32,8 @@ hard_config = { "container_log_streaming_interval": 2, # Seconds "maximum_service_loop_time": 900, # Seconds, failsafe variable - if service is stuck processing longer than this timeframe it will lead into restarting the app "maximum_pull_service_loop_time": 14400, # Exception for image pulling - "creation_engine": "wrapper" # "wrapper" or "sdk" | Wrapper - wrapped docker cli, SDK - docker sdk + "creation_engine": "wrapper", # "wrapper" or "sdk" | Wrapper - wrapped docker cli, SDK - docker sdk + "allow_mixed_gpus": False } parser = argparse.ArgumentParser(description='Example argparse usage') diff --git a/lib/get_specs.py b/lib/get_specs.py index f0fddaf..0014a67 100644 --- a/lib/get_specs.py +++ b/lib/get_specs.py @@ -291,9 +291,17 @@ class Specs: def __init__(self): self.motherboard_name_file = "/sys/devices/virtual/dmi/id/board_name" - async def get(self, benchmark_internet=False, benchmark_disk=False): + async def get(self, benchmark_internet=False, benchmark_disk=False, require_same_gpus=False): total_threads, total_cores, model_name = self.get_cpu_info() gpu_str, gpu_mem, gpus, nvml_err = get_gpu_info() + if require_same_gpus: + last_gpu_name='' + for gpu in gpus: + if not last_gpu_name: + last_gpu_name=gpu["name"] + elif last_gpu_name!=gpu["name"]: + print("\033[31mMixed GPU machines are not allowed\033[0m") + sys.exit(1) docker_daemon_config = docker_interface.get_daemon_config() disk_str="" data_root_location="main_disk" @@ -337,7 +345,10 @@ class Specs: data_root_location="separate" disk_str = f"{disk_type} {overlay_total_size}GB" - disk_speeds = await disk_speed() + if benchmark_disk: + disk_speeds = await disk_speed() + else: + disk_speeds = [0,0] response = { "mb": await self.motherboard_type(), diff --git a/lib/init_server.py b/lib/init_server.py index bfe9d3c..c247f0f 100644 --- a/lib/init_server.py +++ b/lib/init_server.py @@ -86,7 +86,7 @@ async def work_init(loader_event, init_token): log.error("Cuda must be version 11.7+") os._exit(1) - machine_specs = await specs.get(benchmark_internet=True) + machine_specs = await specs.get(benchmark_internet=True, benchmark_disk=True, require_same_gpus=(not config.allow_mixed_gpus)) loader_event.clear() complete_loader()