diff --git a/clore_hosting/main.py b/clore_hosting/main.py index 5779abd..f391f78 100644 --- a/clore_hosting/main.py +++ b/clore_hosting/main.py @@ -41,9 +41,9 @@ async def configure_networks(containers): except Exception as e: return False -async def deploy_containers(validated_containers): +async def deploy_containers(validated_containers, allowed_running_containers): try: - all_running_container_names, all_stopped_container_names = await asyncio.to_thread(docker_deploy.deploy, validated_containers) + all_running_container_names, all_stopped_container_names = await asyncio.to_thread(docker_deploy.deploy, validated_containers, allowed_running_containers) return types.DeployContainersRes(all_running_container_names=all_running_container_names, all_stopped_container_names=all_stopped_container_names) except Exception as e: return False @@ -122,6 +122,7 @@ class CloreClient: nvml.init(allow_hive_binaries=not self.dont_use_hive_binaries) self.extra_allowed_images = utils.get_extra_allowed_images() + self.allowed_running_containers = utils.get_allowed_container_names() self.gpu_oc_specs = nvml.get_gpu_oc_specs() self.last_oc_service_submit = 0 @@ -140,7 +141,7 @@ class CloreClient: task1 = asyncio.create_task(self.main(pull_list, monitoring)) task2 = asyncio.create_task(self.handle_container_cache(pull_list, monitoring)) task3 = asyncio.create_task(self.startup_script_runner(monitoring)) - task4 = asyncio.create_task(log_streaming_task.log_streaming_task(container_log_broken, monitoring)) + task4 = asyncio.create_task(log_streaming_task.log_streaming_task(container_log_broken, monitoring, self.allowed_running_containers)) task5 = asyncio.create_task(self.container_log_streaming_service(monitoring)) task6 = asyncio.create_task(self.specs_service(monitoring)) task7 = asyncio.create_task(self.oc_service(monitoring)) @@ -397,7 +398,7 @@ class CloreClient: tasks.append(WebSocketClient.stream_pull_logs()) if self.validated_containers_set: - tasks.append(deploy_containers(self.validated_containers)) + tasks.append(deploy_containers(self.validated_containers, self.allowed_running_containers)) if step==1: WebSocketClient.set_auth(self.auth_key) diff --git a/lib/docker_deploy.py b/lib/docker_deploy.py index 83838b2..16b7f1b 100644 --- a/lib/docker_deploy.py +++ b/lib/docker_deploy.py @@ -11,7 +11,7 @@ client = docker_interface.client config = config_module.config log = logging_lib.log -def deploy(validated_containers): +def deploy(validated_containers, allowed_running_containers=[]): local_images = docker_interface.get_local_images() all_containers = docker_interface.get_containers(all=True) @@ -166,13 +166,13 @@ def deploy(validated_containers): container.stop() except Exception as e: pass - elif container.name not in paused_names+needed_running_names and container.status == 'running': + elif container.name not in paused_names+needed_running_names+allowed_running_containers and container.status == 'running': try: container.stop() container.remove() except Exception as e: pass - elif container.name not in paused_names+needed_running_names: + elif container.name not in paused_names+needed_running_names+allowed_running_containers: try: container.remove() except Exception as e: diff --git a/lib/log_streaming_task.py b/lib/log_streaming_task.py index 66522bb..97073da 100644 --- a/lib/log_streaming_task.py +++ b/lib/log_streaming_task.py @@ -10,7 +10,7 @@ from lib import container_logs from concurrent.futures import ThreadPoolExecutor import queue # Import the synchronous queue module -async def log_streaming_task(message_broker, monitoring): +async def log_streaming_task(message_broker, monitoring, do_not_stream_containers): client = docker_interface.client executor = ThreadPoolExecutor(max_workers=4) tasks = {} @@ -29,14 +29,15 @@ async def log_streaming_task(message_broker, monitoring): # Start tasks for new containers for container_name, container in current_containers.items(): - log_container_names.append(container_name) - if container_name not in tasks: - log.debug(f"log_streaming_task() | Starting task for {container_name}") - sync_queue = queue.Queue() - task = asyncio.ensure_future(asyncio.get_event_loop().run_in_executor( - executor, container_logs.stream_logs, container_name, sync_queue)) - tasks[container_name] = task - queues[container_name] = sync_queue + if not container_name in do_not_stream_containers: + log_container_names.append(container_name) + if container_name not in tasks: + log.debug(f"log_streaming_task() | Starting task for {container_name}") + sync_queue = queue.Queue() + task = asyncio.ensure_future(asyncio.get_event_loop().run_in_executor( + executor, container_logs.stream_logs, container_name, sync_queue)) + tasks[container_name] = task + queues[container_name] = sync_queue await message_broker.put(log_container_names) diff --git a/lib/utils.py b/lib/utils.py index a34347c..695c8e9 100644 --- a/lib/utils.py +++ b/lib/utils.py @@ -49,6 +49,12 @@ def get_auth(): return auth_str except Exception as e: return '' + +def get_allowed_container_names(): + allowed_container_names = os.getenv("ALLOWED_CONTAINER_NAMES") + if type(allowed_container_names)==str and len(allowed_container_names)>0: + return [x for x in allowed_container_names.split(',') if x] + return [] def unix_timestamp(): return int(time.time())