diff --git a/main.py b/main.py
index 209128f15..bc0af3dd4 100644
--- a/main.py
+++ b/main.py
@@ -6,6 +6,10 @@ import threading
 import queue
 import traceback
 
+if '--dont-upcast-attention' in sys.argv:
+    print("disabling upcasting of attention")
+    os.environ['ATTN_PRECISION'] = "fp16"
+
 import torch
 
 import nodes