from aiofiles.os import stat as aio_stat
from pydantic import BaseModel, Field, constr
import xml.etree.ElementTree as ET
from lib import docker_interface
from typing import Dict, List, Optional
from lib import utils
import subprocess
import speedtest
import platform
import aiofiles
import aiohttp
import asyncio
import shutil
import psutil
import time
import sys
import os
import re

class NvidiaVersionInfo(BaseModel):
    """NVIDIA driver and CUDA versions as reported by ``nvidia-smi -x -q``.

    Both fields default to ``None`` so that an "unknown" instance can be built
    with ``NvidiaVersionInfo()``, which ``get_nvidia_version`` does when
    nvidia-smi is unavailable or either version is missing. With required
    fields that no-argument construction raised a validation error.
    """

    driver_version: Optional[str] = None
    cuda_version: Optional[str] = None

class PCIBusInfo(BaseModel):
    """PCIe link information for one PCI device; each field is None when unknown."""

    # Optional[...] matches the None defaults: a bare ``int`` annotation would
    # reject an explicitly passed None under pydantic v2.
    width: Optional[int] = Field(None, description="The width of the PCI bus")
    revision: Optional[int] = Field(None, description="The revision number of the PCI device", ge=0)

# Module-level example instance, constructed at import time with no arguments,
# so both of its fields keep their None defaults.
example_pci_bus_info = PCIBusInfo()

async def get_cpu_usage():
    """Return system-wide CPU utilisation (percent) sampled over one second.

    ``psutil.cpu_percent`` with a positive interval blocks for that interval,
    so the call is handed to the default executor instead of running on the
    event loop.
    """
    sample_seconds = 1
    running_loop = asyncio.get_running_loop()
    usage = await running_loop.run_in_executor(None, psutil.cpu_percent, sample_seconds)
    return usage

async def get_ram_usage():
    """Return the result of ``psutil.virtual_memory()``, computed in the default executor."""
    return await asyncio.get_running_loop().run_in_executor(None, psutil.virtual_memory)

def get_kernel():
    """Return the running kernel's release string (as ``uname -r`` prints it)."""
    uname_info = platform.uname()
    return uname_info.release

def is_hive():
    """Return True when the kernel release string contains "hive".

    Presumably this identifies HiveOS hosts — TODO confirm.
    """
    release = get_kernel()
    return release.find("hive") != -1

def get_total_ram_mb():
    """Return total physical memory in MiB (bytes / 1024**2) as a float."""
    bytes_per_mib = 1024 * 1024
    return psutil.virtual_memory().total / bytes_per_mib

def get_os_release(path="/etc/os-release"):
    """Parse an os-release file into a dict.

    Every ``KEY=VALUE`` line is stored. A value wrapped in matching double or
    single quotes has the quotes removed. Blank lines and ``#`` comments are
    skipped. ``use_cgroupfs`` is added (True) for Ubuntu 22.04/22.10 to work
    around https://github.com/NVIDIA/nvidia-docker/issues/1730.

    Args:
        path: file to parse; defaults to the system's /etc/os-release.

    Returns:
        The parsed mapping, or ``{}`` if the file cannot be read or decoded.
    """
    # Mitigate issue https://github.com/NVIDIA/nvidia-docker/issues/1730
    needed_cgroupfs_versions = ["22.04", "22.10"]

    try:
        with open(path, encoding="utf-8") as f:
            os_info = f.read()
    except (OSError, ValueError):  # ValueError covers UnicodeDecodeError
        return {}

    os_release = {}
    for line in os_info.splitlines():
        line = line.strip()
        if not line or line.startswith('#') or '=' not in line:
            continue
        key, value = line.split('=', 1)
        if len(value) >= 2 and value[0] == value[-1] and value[0] in ('"', "'"):
            value = value[1:-1]
        # Store unquoted values too (e.g. ID=ubuntu), not only quoted ones.
        os_release[key] = value

    if (os_release.get("NAME", "").lower() == "ubuntu"
            and os_release.get("VERSION_ID") in needed_cgroupfs_versions):
        os_release["use_cgroupfs"] = True

    return os_release

def drop_caches():
    """Best-effort drop of the kernel page cache plus dentries and inodes.

    Writes ``3`` to ``/proc/sys/vm/drop_caches`` so that a following read
    benchmark is served by the device rather than RAM. Dirty pages cannot be
    dropped, so they are first written back with ``os.sync()``, as the kernel
    documentation recommends. Does nothing (silently) when the write is not
    permitted, e.g. when not root, with a read-only /proc, or off Linux.
    """
    sync = getattr(os, "sync", None)  # os.sync exists only on Unix
    if sync is not None:
        sync()
    try:
        with open('/proc/sys/vm/drop_caches', 'w') as f:
            f.write('3\n')
    except OSError:
        # Deliberately best-effort: the benchmark still runs, just warm-cached.
        pass

def write_test(file_path, block_size, num_blocks):
    """Measure synchronous sequential write throughput.

    Writes *num_blocks* copies of one random *block_size*-byte buffer to
    *file_path* (truncating it). Every block is flushed and fsync-ed, so the
    figure reflects the storage device rather than the page cache.

    Returns:
        ``(write_speed, elapsed_seconds)`` where ``write_speed`` is in MiB/s,
        or 0.0 if the elapsed time measured as zero.
    """
    data = os.urandom(block_size)
    total_bytes = block_size * num_blocks

    # perf_counter is monotonic and high-resolution; time.time() can jump if
    # the wall clock is adjusted during the benchmark.
    start_time = time.perf_counter()
    with open(file_path, 'wb') as f:
        for _ in range(num_blocks):
            f.write(data)
            f.flush()
            os.fsync(f.fileno())
    elapsed_time = time.perf_counter() - start_time

    if elapsed_time > 0:
        write_speed = total_bytes / elapsed_time / (1024 * 1024)
    else:
        write_speed = 0.0
    return write_speed, elapsed_time

def read_test(file_path, block_size, num_blocks):
    """Measure sequential read throughput of *file_path*.

    First tries to drop the OS caches (via ``drop_caches``). It then reads up
    to *num_blocks* blocks of *block_size* bytes, stopping early at EOF.

    Returns:
        ``(read_speed, elapsed_seconds)`` where ``read_speed`` is in MiB/s.
        It is computed from the bytes actually read, not the bytes requested,
        and is 0.0 if the elapsed time measured as zero.
    """
    # Drop caches to avoid OS-level caching effects
    drop_caches()

    bytes_read = 0
    start_time = time.perf_counter()
    with open(file_path, 'rb') as f:
        for _ in range(num_blocks):
            data = f.read(block_size)
            if not data:
                break
            bytes_read += len(data)
    elapsed_time = time.perf_counter() - start_time

    if elapsed_time > 0:
        read_speed = bytes_read / elapsed_time / (1024 * 1024)
    else:
        read_speed = 0.0
    return read_speed, elapsed_time

def disk_benchmark():
    """Run a write-then-read throughput benchmark on a scratch file in /tmp.

    Returns ``(0, 0)`` without benchmarking when the filesystem holding /tmp
    has less than 1 GiB free. Otherwise it writes and reads 250 blocks of
    1 MiB (less than 3 GiB free) or 1500 blocks.

    The scratch file is created with ``tempfile.mkstemp``, which gives a
    unique name opened with O_EXCL and mode 0600. It is always removed
    afterwards, even if a test raises.

    Returns:
        ``(write_speed, read_speed)`` in MiB/s, each rounded to 2 decimals.
    """
    import tempfile

    bench_dir = "/tmp"
    # Check free space where the file is actually written: /tmp may be a
    # different filesystem from "/".
    # NOTE(review): if /tmp is tmpfs this measures RAM, not disk — confirm.
    _total, _used, free = shutil.disk_usage(bench_dir)
    free_gb = free / 1024 / 1024 / 1024
    if free_gb < 1:
        return 0, 0

    block_size = 1024 * 1024
    num_blocks = 250 if free_gb < 3 else 1500

    # A fixed name in world-writable /tmp could be pre-planted as a symlink,
    # so that this (possibly root) process overwrites an arbitrary file.
    fd, file_path = tempfile.mkstemp(prefix="disk_benchmark_", dir=bench_dir)
    os.close(fd)

    print("Running disk benchmark...")
    print(f"Block Size: {block_size} bytes, Number of Blocks: {num_blocks}")
    try:
        # Run write test
        write_speed, write_time = write_test(file_path, block_size, num_blocks)
        print(f"Write Speed: {write_speed:.2f} MB/s, Time: {write_time:.2f} seconds")

        # Run read test
        read_speed, read_time = read_test(file_path, block_size, num_blocks)
        print(f"Read Speed: {read_speed:.2f} MB/s, Time: {read_time:.2f} seconds")
    finally:
        # Cleanup, also on failure, so a multi-GB file is never left behind.
        try:
            os.remove(file_path)
        except OSError:
            pass
    return float(round(write_speed, 2)), float(round(read_speed, 2))

def get_nvidia_version():
    """Return the NVIDIA driver and CUDA versions reported by ``nvidia-smi -x -q``.

    Returns ``NvidiaVersionInfo()`` built with no arguments in three cases:
    nvidia-smi cannot be run or fails, the XML cannot be parsed, or either
    version element is missing or empty.

    NOTE(review): that no-argument construction only succeeds if the model's
    fields have defaults.
    """
    try:
        xml_report = subprocess.check_output(['nvidia-smi', '-x', '-q'], encoding='utf-8')
        report_root = ET.fromstring(xml_report)
        driver = report_root.findtext('driver_version')
        cuda = report_root.findtext('.//cuda_version')
        if not (driver and cuda):
            return NvidiaVersionInfo()
        return NvidiaVersionInfo(driver_version=driver, cuda_version=cuda)
    except Exception:
        return NvidiaVersionInfo()

async def measure_internet_speed():
    """Run a speedtest.net measurement without blocking the event loop.

    Returns:
        ``(country_code, download, upload)``. ``download`` and ``upload`` are
        the values from ``st.download()`` and ``st.upload()`` divided by
        1024**2; speedtest reports bits/s — confirm the intended unit.
        Returns ``('', 0, 0)`` on any failure.
    """
    loop = asyncio.get_running_loop()
    try:
        # All of these speedtest calls do blocking network I/O: the
        # constructor fetches the config and get_best_server probes servers.
        # Running them in the default executor keeps the event loop free.
        st = await loop.run_in_executor(None, speedtest.Speedtest)
        server = await loop.run_in_executor(None, st.get_best_server)
        country = server['cc']

        download_speed = await loop.run_in_executor(None, st.download)
        upload_speed = await loop.run_in_executor(None, st.upload)

        return country, download_speed/1024/1024, upload_speed/1024/1024
    except Exception:
        return '', 0, 0
    
async def disk_speed():
    """Run ``disk_benchmark`` in the default executor.

    Returns ``(write_speed, read_speed)``. If the benchmark raises, the
    exception is printed and ``(0, 0)`` is returned.
    """
    try:
        running_loop = asyncio.get_running_loop()
        write_speed, read_speed = await running_loop.run_in_executor(None, disk_benchmark)
    except Exception as exc:
        print("disk benchmark exception", exc)
        return 0, 0
    return write_speed, read_speed
    
async def get_country_code():
    """Look up this host's country code via https://ifconfig.io/all.json.

    Returns:
        The ``country_code`` string from the response. On failure it returns
        an ``"Error: ..."`` message instead. Failures are: a non-200 status,
        a timeout, a network or JSON decoding error, or a response without a
        string ``country_code``.
    """
    try:
        async with aiohttp.ClientSession() as session:
            # Set a 5-second total timeout for the request
            timeout = aiohttp.ClientTimeout(total=5)
            async with session.get("https://ifconfig.io/all.json", timeout=timeout) as response:
                if response.status != 200:
                    return f"Error: Response status {response.status}"
                data = await response.json()
    except asyncio.TimeoutError:
        return "Error: The request timed out after 5 seconds"
    except (aiohttp.ClientError, ValueError) as e:
        # Connection/DNS failures, non-JSON content type, malformed JSON.
        return f"Error: {e}"

    country_code = data.get("country_code") if isinstance(data, dict) else None
    if not isinstance(country_code, str):
        return "Error: country_code missing from response"
    return country_code

def filter_non_numeric(input_string):
    """Return *input_string* with every character except ASCII digits 0-9 removed.

    >>> filter_non_numeric("81920 MiB")
    '81920'
    """
    ascii_digits = "0123456789"
    return "".join(ch for ch in input_string if ch in ascii_digits)

def get_disk_udevadm(mount_point='/'):
    """Return ``udevadm info --query=all`` output for the disk backing *mount_point*.

    The lookup runs in three steps:
    1. ``findmnt`` gives the mount's source device.
    2. ``lsblk -no pkname`` gives that device's parent disk.
    3. ``udevadm`` is queried for the parent disk.

    Returns '' if any command exits non-zero, writes to stderr, or raises.
    """
    try:
        rc, source_dev, err = utils.run_command(f"findmnt -n -o SOURCE {mount_point}")
        if rc != 0 or err != '':
            return ''

        rc, parent_disk, err = utils.run_command(f"lsblk -no pkname {source_dev}")
        if rc != 0 or err != '':
            return ''
        if not parent_disk.startswith("/dev/"):
            parent_disk = f"/dev/{parent_disk}"

        rc, udev_info, err = utils.run_command(f"udevadm info --query=all --name={parent_disk}")
        if rc != 0 or err != '':
            return ''
        return udev_info
    except Exception:
        return ''

def get_bus_spec(bus_id):
    """Read the current PCIe link generation and width of a PCI device from sysfs.

    Args:
        bus_id: directory name under /sys/bus/pci/devices (e.g. "0000:01:00.0").

    Returns:
        ``PCIBusInfo(revision=<PCIe generation>, width=<current link width>)``.
        The generation comes from the leading number (GT/s) of
        ``current_link_speed``; unknown or unparsable speeds give 1. On any
        error the exception is printed and ``PCIBusInfo()`` is returned.
    """
    # Per-lane transfer rate in GT/s -> PCIe generation.
    speed_to_rev_mapping = {
        2.5: 1,
        5.0: 2,
        8.0: 3,
        16.0: 4,
        32.0: 5,
        64.0: 6,
        128.0: 7,
    }
    try:
        device_dir = f"/sys/bus/pci/devices/{bus_id}"
        with open(f"{device_dir}/current_link_speed", "r", encoding="utf-8") as file:
            current_link_speed = file.read().strip()
        with open(f"{device_dir}/current_link_width", "r", encoding="utf-8") as file:
            current_link_width = int(file.read().strip())

        # The speed reads like "8.0 GT/s PCIe" (some older kernels print
        # "5 GT/s"). Parse the number rather than substring-matching: the old
        # matching mapped "128.0" (which contains "8") to Gen3 and missed
        # "5 GT/s".
        pci_revision = 1  # Default value
        match = re.match(r'\s*(\d+(?:\.\d+)?)', current_link_speed)
        if match:
            pci_revision = speed_to_rev_mapping.get(float(match.group(1)), 1)

        return PCIBusInfo(revision=pci_revision, width=current_link_width)
    except Exception as e:
        print(e)
        return PCIBusInfo()

def get_gpu_info():
    """Collect NVIDIA GPU information by parsing two ``nvidia-smi`` CSV queries.

    Returns:
        (gpu_str, gpu_mem, gpus, nvml_err):
        - gpu_str: "<count>x <name>", built from the first query;
          "0x Unknown" if no data line was parsed.
        - gpu_mem: memory.total of the last data line of the first query,
          digits only, divided by 1024 and rounded to 2 dp (0 if none parsed).
          Presumably MiB -> GiB; confirm the nvidia-smi unit.
        - gpus: {"nvidia": [one dict per data line of the second query],
          "amd": []}.
        - nvml_err: True in any of these cases: either command's stdout or
          stderr contains "Failed to initialize NVML", either command returns
          non-zero, or parsing raises. On a parse error, GPUs appended before
          the failure stay in gpus["nvidia"].
    """
    gpu_str = "0x Unknown"
    nvml_err = False
    gpu_mem = 0
    gpus={
        "nvidia":[],
        "amd":[] # To be implemented in future releases
    }

    # sysfs PCI device names, used below to look up each GPU's link width/revision.
    valid_pci_dev_list = []

    try:
        valid_pci_dev_list = os.listdir("/sys/bus/pci/devices")
    except Exception as e:
        # Best effort: with an empty list the GPUs just lack pcie_* fields.
        pass

    # Query 1: index,name,uuid,serial,memory.total (used for gpu_str/gpu_mem).
    nvidia_smi_return_code, nvidia_smi_stdout, nvidia_smi_stderr = utils.run_command(f"nvidia-smi --query-gpu=index,name,uuid,serial,memory.total --format=csv")
    # Query 2: detailed per-GPU fields; column positions are indexed as parts[0..12] below.
    nvidia_smi_xl_return_code, nvidia_smi_xl_stdout, nvidia_smi_xl_stderr = utils.run_command("nvidia-smi --query-gpu=timestamp,name,pci.bus_id,driver_version,pstate,pcie.link.gen.max,pcie.link.gen.current,temperature.gpu,utilization.gpu,utilization.memory,memory.total,memory.free,memory.used --format=csv")

    if "Failed to initialize NVML" in nvidia_smi_stdout or "Failed to initialize NVML" in nvidia_smi_stderr or "Failed to initialize NVML" in nvidia_smi_xl_stdout or "Failed to initialize NVML" in nvidia_smi_xl_stderr:
        nvml_err=True
    elif nvidia_smi_return_code==0 and nvidia_smi_xl_return_code==0:
        try:
            lines_xl = nvidia_smi_xl_stdout.split('\n')
            for index, line in enumerate(lines_xl):
                parts = [s.strip() for s in line.split(',')]
                # index 0 is the CSV header; lines with fewer than 13 fields are skipped.
                if len(parts)>12 and index>0:
                    # parts[5]/parts[6] (pcie.link.gen.max/current) are queried but not stored.
                    xl_gpu_info={
                        "id":index-1,
                        "timestamp": parts[0],
                        "name": parts[1],
                        # Everything after the first ':' of pci.bus_id; IndexError (-> nvml_err) if there is no ':'.
                        "pcie_bus": parts[2].split(':', 1)[1],
                        "driver": parts[3],
                        "pstate": parts[4],
                        "temp": parts[7],
                        # int() raises on non-numeric values (e.g. "[N/A]"), aborting the whole parse.
                        "core_utilization": int(parts[8].replace(" %",'')),
                        "mem_utilization": int(parts[9].replace(" %",'')),
                        "mem_total": parts[10],
                        "mem_free": parts[11],
                        "mem_used": parts[12]
                    }
                    try:
                        # Same suffix as "pcie_bus"; matched as a case-insensitive substring of sysfs names.
                        pci_query = parts[2][parts[2].find(':')+1:]
                        # NOTE: reuses the name `index`; harmless because the outer value was already consumed above.
                        for index, valid_pci_dev in enumerate(valid_pci_dev_list):
                            if pci_query.lower() in valid_pci_dev.lower():
                                # Every matching device overwrites the fields, so the last match wins.
                                bus_spec = get_bus_spec(valid_pci_dev)
                                if bus_spec.width!=None and bus_spec.revision!=None:
                                    xl_gpu_info["pcie_width"]=bus_spec.width
                                    xl_gpu_info["pcie_revision"]=bus_spec.revision
                    except Exception as e:
                        pass
                    gpus["nvidia"].append(xl_gpu_info)
            lines = nvidia_smi_stdout.split('\n')
            for line in lines:
                parts = line.split(',')
                # Data lines start with the numeric GPU index; the header does not.
                if bool(re.match(r'^[0-9]+$', parts[0])):
                    # NOTE(review): counts all output lines minus the header; if
                    # utils.run_command leaves a trailing empty line this over-counts
                    # by one — confirm it strips stdout.
                    gpu_str = f"{len(lines)-1}x {parts[1].strip()}"
                    gpu_mem = round(int(filter_non_numeric(parts[4]).strip())/1024, 2)
        except Exception as e:
            nvml_err=True
            pass
    else:
        nvml_err=True

    return gpu_str, gpu_mem, gpus, nvml_err



class DockerDaemonConfig(BaseModel):
    """Subset of Docker's daemon configuration, keyed by its hyphenated option names."""

    data_root: str = Field(alias="data-root")
    storage_driver: str = Field(alias="storage-driver")
    # Explicit None default: in pydantic v2 an Optional annotation alone does
    # not make a field optional, so a config without "storage-opts" would fail
    # validation.
    storage_opts: Optional[List[str]] = Field(None, alias="storage-opts")

class Specs:
    """Collects a hardware/software description of this host.

    :meth:`get` assembles the full report. The other methods each read one
    fact (CPU, RAM, swap, motherboard) from the local system.
    """

    def __init__(self):
        # sysfs DMI attribute holding the motherboard (board) name.
        self.motherboard_name_file = "/sys/devices/virtual/dmi/id/board_name"

    async def get(self, benchmark_internet=False, benchmark_disk=False, require_same_gpus=False):
        """Gather the host specification.

        Args:
            benchmark_internet: also run a speedtest plus a country-code lookup
                and add the results under the ``"net"`` key.
            benchmark_disk: run the disk benchmark; otherwise disk speed is 0.
            require_same_gpus: exit the process with status 1 if the NVIDIA
                GPUs reported by ``get_gpu_info`` do not all share one name.

        Returns:
            dict with keys ``mb``, ``cpu``, ``cpus``, ``ram``, ``swap``,
            ``data_root_location``, ``disk``, ``disk_speed``, ``gpu``,
            ``gpuram``, ``gpus``, ``nvml_error``, plus ``net`` when
            *benchmark_internet* is set.

        Exits the process with status 1 if the Docker daemon config is not a dict.
        """
        total_threads, total_cores, model_name = self.get_cpu_info()
        gpu_str, gpu_mem, gpus, nvml_err = get_gpu_info()
        if require_same_gpus:
            last_gpu_name = ''
            for gpu in gpus["nvidia"]:
                if not last_gpu_name:
                    last_gpu_name = gpu["name"]
                elif last_gpu_name != gpu["name"]:
                    print("\033[31mMixed GPU machines are not allowed\033[0m")
                    sys.exit(1)

        docker_daemon_config = docker_interface.get_daemon_config()
        if not isinstance(docker_daemon_config, dict):
            sys.exit(1)
        data_root_location = "main_disk"
        disk_str = self._describe_disk(docker_daemon_config)

        if benchmark_disk:
            disk_speeds = await disk_speed()
        else:
            disk_speeds = [0, 0]

        response = {
            "mb": await self.motherboard_type(),
            "cpu": model_name,
            "cpus": f"{total_cores}/{total_threads}",
            "ram": self.get_ram_size(),
            "swap": self.get_swap_size(),
            "data_root_location": data_root_location,
            "disk": disk_str,
            "disk_speed": disk_speeds[1],  # read speed
            "gpu": gpu_str,
            "gpuram": gpu_mem,
            "gpus": gpus,
            "nvml_error": nvml_err
        }
        if benchmark_internet:
            country, download_speed, upload_speed = await measure_internet_speed()
            if country == '':
                download_speed = 0
                upload_speed = 0
            possible_cc = await get_country_code()
            # The lookup may produce something other than a short code (e.g.
            # an "Error: ..." message, or None). Only a non-empty short string
            # overrides the speedtest country.
            if isinstance(possible_cc, str) and 0 < len(possible_cc) < 4:
                country = possible_cc

            response["net"] = {
                "cc": country,
                "down": download_speed,
                "up": upload_speed
            }

        return response

    @staticmethod
    def _describe_disk(docker_daemon_config):
        """Return "<disk model> <free GB>GB" for the storage Docker uses.

        Free space is measured on Docker's data-root when the storage driver
        is overlay2 and a data-root is configured, otherwise on "/". The disk
        model comes from udevadm's ID_MODEL for the disk backing "/"; it is
        "Virtual" when DEVPATH lies under /virtual/, and empty if unknown.
        """
        disk_usage_source_path = '/'
        if (docker_daemon_config.get("storage-driver") == "overlay2"
                and "data-root" in docker_daemon_config):
            disk_usage_source_path = docker_daemon_config["data-root"]
        _total, _used, free = shutil.disk_usage(disk_usage_source_path)

        disk_type = ""
        # NOTE(review): the model is read for the disk behind "/" even when
        # free space is measured on data-root — confirm this is intended.
        for udevadm_line in get_disk_udevadm("/").split('\n'):
            key, sep, value = udevadm_line.partition('=')
            if not sep:
                continue
            # Lines look like "E: ID_MODEL=..."; compare the exact property
            # name, since a substring test also hit ID_MODEL_ENC / ID_MODEL_ID.
            prop = key.split(':', 1)[-1].strip().lower()
            if prop == "id_model":
                disk_type = value[:24]
            elif prop == "devpath" and "/virtual/" in value:
                disk_type = "Virtual"
        return f"{disk_type} {round(free / (1024**3), 4)}GB"

    async def read_file(self, file_path):
        """Return the text contents of *file_path*, or None if it cannot be read."""
        try:
            async with aiofiles.open(file_path, mode='r') as file:
                contents = await file.read()
                return contents
        except Exception:
            return None

    async def check_file_existence(self, file_path):
        """Return True if ``stat`` on *file_path* succeeds, False otherwise."""
        try:
            await aio_stat(file_path)
            return True
        except Exception:
            return False

    async def motherboard_type(self):
        """Return the board name (newlines removed, max 32 chars) or "Unknown"."""
        if await self.check_file_existence(self.motherboard_name_file):
            motherboard_type = await self.read_file(self.motherboard_name_file)
            return motherboard_type.replace('\n', '')[:32] if motherboard_type is not None else "Unknown"
        return "Unknown"

    def get_cpu_info(self):
        """Return ``(total_threads, total_cores, model_name)`` for the host CPU(s).

        ``total_threads`` is ``os.cpu_count()``. ``total_cores`` is that value
        divided by lscpu's "Thread(s) per core", which defaults to 1 if it is
        missing, zero, or non-numeric. ``model_name`` is "Unknown CPU" if
        lscpu reports none. Raises if ``lscpu`` cannot be run or fails.
        """
        # Force the C locale: lscpu translates its field labels, and the
        # parser below matches the English ones.
        lscpu_env = dict(os.environ, LC_ALL="C")
        lscpu_out = subprocess.check_output(['lscpu'], env=lscpu_env).decode('utf-8')
        threads_per_core = 1
        total_threads = os.cpu_count()
        model_name = "Unknown CPU"
        for line in lscpu_out.split('\n'):
            key, sep, value = line.partition(':')
            if not sep:
                continue
            value = value.strip(' ')
            if "model name" in key.lower():
                model_name = value
            elif key == "Thread(s) per core":
                try:
                    parsed = int(value)
                except ValueError:
                    continue
                if parsed:
                    threads_per_core = parsed
        total_cores = int(total_threads / threads_per_core)
        return total_threads, total_cores, model_name

    @staticmethod
    def _meminfo_gb(field):
        """Return /proc/meminfo *field* (a kB value) / 1024 / 1000, rounded to 4 dp.

        Returns 0 if the file cannot be read or parsed, or the field is absent.
        """
        try:
            with open('/proc/meminfo', 'r') as f:
                for line in f:
                    if line.startswith(field):
                        kb = int(line.split()[1])
                        # NOTE: divides by 1024 then 1000, which is neither GiB
                        # nor GB; kept so reported values stay comparable.
                        return round(kb / 1024 / 1000, 4)
        except (OSError, ValueError, IndexError):
            return 0
        return 0

    def get_ram_size(self):
        """Return total RAM (MemTotal) in this module's "GB" unit; 0 if unavailable."""
        return self._meminfo_gb('MemTotal')

    def get_swap_size(self):
        """Return total swap (SwapTotal) in this module's "GB" unit; 0 if unavailable."""
        return self._meminfo_gb('SwapTotal')

def get_root_device():
    """Return the source of the "/" mount as printed by ``findmnt`` (e.g. "/dev/sda2").

    Returns None if findmnt is not installed or exits with an error.
    """
    try:
        mount_info = subprocess.check_output(['findmnt', '-n', '-o', 'SOURCE', '/']).decode().strip()
    except (subprocess.CalledProcessError, OSError):
        # OSError covers FileNotFoundError when findmnt is missing.
        return None
    return mount_info

def is_usb_device(device):
    """Return True if block *device* (e.g. "/dev/sda2") sits on a USB-attached disk.

    Partitions and other stacked devices have no transport of their own, so
    the parent chain (lsblk PKNAME) is followed until a device reporting a
    transport (TRAN) is found.

    Returns:
        True if lsblk cannot be run or fails (an unknown result is treated as
        USB). False if the device is not listed or has no USB ancestor.
    """
    try:
        # -P prints KEY="value" pairs, so empty columns (a partition's TRAN)
        # stay unambiguous, unlike the default tree output.
        lsblk_output = subprocess.check_output(
            ['lsblk', '-n', '-P', '-o', 'NAME,TRAN,PKNAME']).decode()
    except (subprocess.CalledProcessError, OSError):
        return True

    devices = {}
    for line in lsblk_output.splitlines():
        fields = dict(re.findall(r'(\w+)="([^"]*)"', line))
        name = fields.get('NAME')
        if name:
            devices[name] = (fields.get('TRAN', ''), fields.get('PKNAME', ''))

    name = device.rsplit('/', 1)[-1] if device else ''
    seen = set()
    while name in devices and name not in seen:  # `seen` guards against cycles
        seen.add(name)
        tran, parent = devices[name]
        if tran:
            return tran == 'usb'
        name = parent
    return False