This commit is contained in:
huchenlei 2024-12-08 21:01:41 -05:00
parent 10e08b0554
commit 73b26e5375
2 changed files with 22 additions and 18 deletions

View File

@ -5,20 +5,24 @@ import ctypes
import logging import logging
torch_spec = importlib.util.find_spec("torch") def fix_pytorch_libomp():
for folder in torch_spec.submodule_search_locations: """
lib_folder = os.path.join(folder, "lib") Fix PyTorch libomp DLL issue on Windows by copying the correct DLL file if needed.
test_file = os.path.join(lib_folder, "fbgemm.dll") """
dest = os.path.join(lib_folder, "libomp140.x86_64.dll") torch_spec = importlib.util.find_spec("torch")
if os.path.exists(dest): for folder in torch_spec.submodule_search_locations:
break lib_folder = os.path.join(folder, "lib")
test_file = os.path.join(lib_folder, "fbgemm.dll")
with open(test_file, 'rb') as f: dest = os.path.join(lib_folder, "libomp140.x86_64.dll")
contents = f.read() if os.path.exists(dest):
if b"libomp140.x86_64.dll" not in contents:
break break
try:
mydll = ctypes.cdll.LoadLibrary(test_file) with open(test_file, "rb") as f:
except FileNotFoundError as e: contents = f.read()
logging.warning("Detected pytorch version with libomp issue, patching.") if b"libomp140.x86_64.dll" not in contents:
shutil.copyfile(os.path.join(lib_folder, "libiomp5md.dll"), dest) break
try:
mydll = ctypes.cdll.LoadLibrary(test_file)
except FileNotFoundError as e:
logging.warning("Detected pytorch version with libomp issue, patching.")
shutil.copyfile(os.path.join(lib_folder, "libiomp5md.dll"), dest)

View File

@ -86,9 +86,9 @@ if __name__ == "__main__":
import cuda_malloc import cuda_malloc
if args.windows_standalone_build: if args.windows_standalone_build:
# TODO: Convert fix_torch to a function.
try: try:
import fix_torch # noqa: F401 from fix_torch import fix_pytorch_libomp
fix_pytorch_libomp()
except: except:
pass pass