ALLOWED_CONTAINER_NAMES env variable

This commit is contained in:
clore 2024-09-04 12:18:23 +00:00
parent 590dc4b65e
commit cab037526a
4 changed files with 24 additions and 16 deletions

View File

@ -41,9 +41,9 @@ async def configure_networks(containers):
except Exception as e: except Exception as e:
return False return False
async def deploy_containers(validated_containers): async def deploy_containers(validated_containers, allowed_running_containers):
try: try:
all_running_container_names, all_stopped_container_names = await asyncio.to_thread(docker_deploy.deploy, validated_containers) all_running_container_names, all_stopped_container_names = await asyncio.to_thread(docker_deploy.deploy, validated_containers, allowed_running_containers)
return types.DeployContainersRes(all_running_container_names=all_running_container_names, all_stopped_container_names=all_stopped_container_names) return types.DeployContainersRes(all_running_container_names=all_running_container_names, all_stopped_container_names=all_stopped_container_names)
except Exception as e: except Exception as e:
return False return False
@ -122,6 +122,7 @@ class CloreClient:
nvml.init(allow_hive_binaries=not self.dont_use_hive_binaries) nvml.init(allow_hive_binaries=not self.dont_use_hive_binaries)
self.extra_allowed_images = utils.get_extra_allowed_images() self.extra_allowed_images = utils.get_extra_allowed_images()
self.allowed_running_containers = utils.get_allowed_container_names()
self.gpu_oc_specs = nvml.get_gpu_oc_specs() self.gpu_oc_specs = nvml.get_gpu_oc_specs()
self.last_oc_service_submit = 0 self.last_oc_service_submit = 0
@ -140,7 +141,7 @@ class CloreClient:
task1 = asyncio.create_task(self.main(pull_list, monitoring)) task1 = asyncio.create_task(self.main(pull_list, monitoring))
task2 = asyncio.create_task(self.handle_container_cache(pull_list, monitoring)) task2 = asyncio.create_task(self.handle_container_cache(pull_list, monitoring))
task3 = asyncio.create_task(self.startup_script_runner(monitoring)) task3 = asyncio.create_task(self.startup_script_runner(monitoring))
task4 = asyncio.create_task(log_streaming_task.log_streaming_task(container_log_broken, monitoring)) task4 = asyncio.create_task(log_streaming_task.log_streaming_task(container_log_broken, monitoring, self.allowed_running_containers))
task5 = asyncio.create_task(self.container_log_streaming_service(monitoring)) task5 = asyncio.create_task(self.container_log_streaming_service(monitoring))
task6 = asyncio.create_task(self.specs_service(monitoring)) task6 = asyncio.create_task(self.specs_service(monitoring))
task7 = asyncio.create_task(self.oc_service(monitoring)) task7 = asyncio.create_task(self.oc_service(monitoring))
@ -397,7 +398,7 @@ class CloreClient:
tasks.append(WebSocketClient.stream_pull_logs()) tasks.append(WebSocketClient.stream_pull_logs())
if self.validated_containers_set: if self.validated_containers_set:
tasks.append(deploy_containers(self.validated_containers)) tasks.append(deploy_containers(self.validated_containers, self.allowed_running_containers))
if step==1: if step==1:
WebSocketClient.set_auth(self.auth_key) WebSocketClient.set_auth(self.auth_key)

View File

@ -11,7 +11,7 @@ client = docker_interface.client
config = config_module.config config = config_module.config
log = logging_lib.log log = logging_lib.log
def deploy(validated_containers): def deploy(validated_containers, allowed_running_containers=[]):
local_images = docker_interface.get_local_images() local_images = docker_interface.get_local_images()
all_containers = docker_interface.get_containers(all=True) all_containers = docker_interface.get_containers(all=True)
@ -166,13 +166,13 @@ def deploy(validated_containers):
container.stop() container.stop()
except Exception as e: except Exception as e:
pass pass
elif container.name not in paused_names+needed_running_names and container.status == 'running': elif container.name not in paused_names+needed_running_names+allowed_running_containers and container.status == 'running':
try: try:
container.stop() container.stop()
container.remove() container.remove()
except Exception as e: except Exception as e:
pass pass
elif container.name not in paused_names+needed_running_names: elif container.name not in paused_names+needed_running_names+allowed_running_containers:
try: try:
container.remove() container.remove()
except Exception as e: except Exception as e:

View File

@ -10,7 +10,7 @@ from lib import container_logs
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
import queue # Import the synchronous queue module import queue # Import the synchronous queue module
async def log_streaming_task(message_broker, monitoring): async def log_streaming_task(message_broker, monitoring, do_not_stream_containers):
client = docker_interface.client client = docker_interface.client
executor = ThreadPoolExecutor(max_workers=4) executor = ThreadPoolExecutor(max_workers=4)
tasks = {} tasks = {}
@ -29,14 +29,15 @@ async def log_streaming_task(message_broker, monitoring):
# Start tasks for new containers # Start tasks for new containers
for container_name, container in current_containers.items(): for container_name, container in current_containers.items():
log_container_names.append(container_name) if not container_name in do_not_stream_containers:
if container_name not in tasks: log_container_names.append(container_name)
log.debug(f"log_streaming_task() | Starting task for {container_name}") if container_name not in tasks:
sync_queue = queue.Queue() log.debug(f"log_streaming_task() | Starting task for {container_name}")
task = asyncio.ensure_future(asyncio.get_event_loop().run_in_executor( sync_queue = queue.Queue()
executor, container_logs.stream_logs, container_name, sync_queue)) task = asyncio.ensure_future(asyncio.get_event_loop().run_in_executor(
tasks[container_name] = task executor, container_logs.stream_logs, container_name, sync_queue))
queues[container_name] = sync_queue tasks[container_name] = task
queues[container_name] = sync_queue
await message_broker.put(log_container_names) await message_broker.put(log_container_names)

View File

@ -49,6 +49,12 @@ def get_auth():
return auth_str return auth_str
except Exception as e: except Exception as e:
return '' return ''
def get_allowed_container_names():
allowed_container_names = os.getenv("ALLOWED_CONTAINER_NAMES")
if type(allowed_container_names)==str and len(allowed_container_names)>0:
return [x for x in allowed_container_names.split(',') if x]
return []
def unix_timestamp(): def unix_timestamp():
return int(time.time()) return int(time.time())