This commit is contained in:
huchenlei 2024-12-08 21:01:41 -05:00
parent 10e08b0554
commit 73b26e5375
2 changed files with 22 additions and 18 deletions

View File

@ -5,6 +5,10 @@ import ctypes
import logging import logging
def fix_pytorch_libomp():
"""
Fix PyTorch libomp DLL issue on Windows by copying the correct DLL file if needed.
"""
torch_spec = importlib.util.find_spec("torch") torch_spec = importlib.util.find_spec("torch")
for folder in torch_spec.submodule_search_locations: for folder in torch_spec.submodule_search_locations:
lib_folder = os.path.join(folder, "lib") lib_folder = os.path.join(folder, "lib")
@ -13,7 +17,7 @@ for folder in torch_spec.submodule_search_locations:
if os.path.exists(dest): if os.path.exists(dest):
break break
with open(test_file, 'rb') as f: with open(test_file, "rb") as f:
contents = f.read() contents = f.read()
if b"libomp140.x86_64.dll" not in contents: if b"libomp140.x86_64.dll" not in contents:
break break

View File

@ -86,9 +86,9 @@ if __name__ == "__main__":
import cuda_malloc import cuda_malloc
if args.windows_standalone_build: if args.windows_standalone_build:
# TODO: Convert fix_torch to a function.
try: try:
import fix_torch # noqa: F401 from fix_torch import fix_pytorch_libomp
fix_pytorch_libomp()
except: except:
pass pass