Fix cond_cat to not try to cast anything that doesn't have a 'to' function

This commit is contained in:
Jedrzej Kosinski 2025-01-10 23:05:24 -06:00
parent d3cf2b7b24
commit e88c6c03ff

View File

@ -143,7 +143,7 @@ def cond_cat(c_list, device=None):
for k in temp:
conds = temp[k]
out[k] = conds[0].concat(conds[1:])
if device is not None:
if device is not None and hasattr(out[k], 'to'):
out[k] = out[k].to(device)
return out