diff --git a/app/venv_management.py b/app/venv_management.py new file mode 100644 index 000000000..412c332af --- /dev/null +++ b/app/venv_management.py @@ -0,0 +1,125 @@ +import torch +import torchvision +import torchaudio +from dataclasses import dataclass + +import importlib +if importlib.util.find_spec("torch_directml"): + from pip._vendor import pkg_resources + + +class VEnvException(Exception): + pass + + +@dataclass +class TorchVersionInfo: + name: str = None + version: str = None + extension: str = None + is_nightly: bool = False + is_cpu: bool = False + is_cuda: bool = False + is_xpu: bool = False + is_rocm: bool = False + is_directml: bool = False + + +def get_bootstrap_requirements_string(): + ''' + Get string to insert into a 'pip install' command to get the same torch dependencies as current venv. + ''' + torch_info = get_torch_info(torch) + packages = [torchvision, torchaudio] + infos = [torch_info] + [get_torch_info(x) for x in packages] + # directml should be first dependency, if exists + directml_info = get_torch_directml_info() + if directml_info is not None: + infos = [directml_info] + infos + # create list of strings to combine into install string + install_str_list = [] + for info in infos: + info_string = f"{info.name}=={info.version}" + if not info.is_cpu and not info.is_directml: + info_string = f"{info_string}+{info.extension}" + install_str_list.append(info_string) + # handle extra_index_url, if needed + extra_index_url = get_index_url(torch_info) + if extra_index_url: + install_str_list.append(extra_index_url) + # format nightly install properly + if torch_info.is_nightly: + install_str_list = ["--pre"] + install_str_list + + install_str = " ".join(install_str_list) + return install_str + +def get_index_url(info: TorchVersionInfo=None): + ''' + Get --extra-index-url (or --index-url) for torch install. + ''' + if info is None: + info = get_torch_info() + # for cpu, don't need any index_url + if info.is_cpu and not info.is_nightly: + return None + # otherwise, format index_url + base_url = "https://download.pytorch.org/whl/" + if info.is_nightly: + base_url = f"--index-url {base_url}nightly/" + else: + base_url = f"--extra-index-url {base_url}" + base_url = f"{base_url}{info.extension}" + return base_url + +def get_torch_info(package=None): + ''' + Get info about an installed torch-related package. + ''' + if package is None: + package = torch + info = TorchVersionInfo(name=package.__name__) + info.version = package.__version__ + info.extension = None + info.is_nightly = False + # get extension, separate from version + info.version, info.extension = info.version.split('+', 1) + if info.extension.startswith('cpu'): + info.is_cpu = True + elif info.extension.startswith('cu'): + info.is_cuda = True + elif info.extension.startswith('rocm'): + info.is_rocm = True + elif info.extension.startswith('xpu'): + info.is_xpu = True + # TODO: add checks for some odd pytorch versions, if possible + + # check if nightly install + if 'dev' in info.version: + info.is_nightly = True + + return info + +def get_torch_directml_info(): + ''' + Get info specifically about torch-directml package. + + Returns None if torch-directml is not installed. + ''' + # the import string and the pip string are different + pip_name = "torch-directml" + # if no torch_directml, do nothing + if not importlib.util.find_spec("torch_directml"): + return None + info = TorchVersionInfo(name=pip_name) + info.is_directml = True + for p in pkg_resources.working_set: + if p.project_name.lower() == pip_name: + info.version = p.version + if p.version is None: + return None + return info + + +if __name__ == '__main__': + print(get_bootstrap_requirements_string())