mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-25 15:55:18 +00:00
Lint and fix undefined names (1/N) (#6028)
This commit is contained in:
parent
60749f345d
commit
2cddbf0821
@ -1,3 +1,4 @@
|
|||||||
|
import math
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from .ldm.modules.attention import CrossAttention
|
from .ldm.modules.attention import CrossAttention
|
||||||
|
@ -228,9 +228,9 @@ class FeedForward(nn.Module):
|
|||||||
linear_in = GLU(dim, inner_dim, activation, dtype=dtype, device=device, operations=operations)
|
linear_in = GLU(dim, inner_dim, activation, dtype=dtype, device=device, operations=operations)
|
||||||
else:
|
else:
|
||||||
linear_in = nn.Sequential(
|
linear_in = nn.Sequential(
|
||||||
Rearrange('b n d -> b d n') if use_conv else nn.Identity(),
|
rearrange('b n d -> b d n') if use_conv else nn.Identity(),
|
||||||
operations.Linear(dim, inner_dim, bias = not no_bias, dtype=dtype, device=device) if not use_conv else operations.Conv1d(dim, inner_dim, conv_kernel_size, padding = (conv_kernel_size // 2), bias = not no_bias, dtype=dtype, device=device),
|
operations.Linear(dim, inner_dim, bias = not no_bias, dtype=dtype, device=device) if not use_conv else operations.Conv1d(dim, inner_dim, conv_kernel_size, padding = (conv_kernel_size // 2), bias = not no_bias, dtype=dtype, device=device),
|
||||||
Rearrange('b n d -> b d n') if use_conv else nn.Identity(),
|
rearrange('b n d -> b d n') if use_conv else nn.Identity(),
|
||||||
activation
|
activation
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -245,9 +245,9 @@ class FeedForward(nn.Module):
|
|||||||
|
|
||||||
self.ff = nn.Sequential(
|
self.ff = nn.Sequential(
|
||||||
linear_in,
|
linear_in,
|
||||||
Rearrange('b d n -> b n d') if use_conv else nn.Identity(),
|
rearrange('b d n -> b n d') if use_conv else nn.Identity(),
|
||||||
linear_out,
|
linear_out,
|
||||||
Rearrange('b n d -> b d n') if use_conv else nn.Identity(),
|
rearrange('b n d -> b d n') if use_conv else nn.Identity(),
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
import logging
|
||||||
|
import math
|
||||||
import torch
|
import torch
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from typing import Any, Dict, Tuple, Union
|
from typing import Any, Dict, Tuple, Union
|
||||||
@ -52,7 +54,7 @@ class AbstractAutoencoder(torch.nn.Module):
|
|||||||
|
|
||||||
if self.use_ema:
|
if self.use_ema:
|
||||||
self.model_ema = LitEma(self, decay=ema_decay)
|
self.model_ema = LitEma(self, decay=ema_decay)
|
||||||
logpy.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
|
logging.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
|
||||||
|
|
||||||
def get_input(self, batch) -> Any:
|
def get_input(self, batch) -> Any:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
@ -68,14 +70,14 @@ class AbstractAutoencoder(torch.nn.Module):
|
|||||||
self.model_ema.store(self.parameters())
|
self.model_ema.store(self.parameters())
|
||||||
self.model_ema.copy_to(self)
|
self.model_ema.copy_to(self)
|
||||||
if context is not None:
|
if context is not None:
|
||||||
logpy.info(f"{context}: Switched to EMA weights")
|
logging.info(f"{context}: Switched to EMA weights")
|
||||||
try:
|
try:
|
||||||
yield None
|
yield None
|
||||||
finally:
|
finally:
|
||||||
if self.use_ema:
|
if self.use_ema:
|
||||||
self.model_ema.restore(self.parameters())
|
self.model_ema.restore(self.parameters())
|
||||||
if context is not None:
|
if context is not None:
|
||||||
logpy.info(f"{context}: Restored training weights")
|
logging.info(f"{context}: Restored training weights")
|
||||||
|
|
||||||
def encode(self, *args, **kwargs) -> torch.Tensor:
|
def encode(self, *args, **kwargs) -> torch.Tensor:
|
||||||
raise NotImplementedError("encode()-method of abstract base class called")
|
raise NotImplementedError("encode()-method of abstract base class called")
|
||||||
@ -84,7 +86,7 @@ class AbstractAutoencoder(torch.nn.Module):
|
|||||||
raise NotImplementedError("decode()-method of abstract base class called")
|
raise NotImplementedError("decode()-method of abstract base class called")
|
||||||
|
|
||||||
def instantiate_optimizer_from_config(self, params, lr, cfg):
|
def instantiate_optimizer_from_config(self, params, lr, cfg):
|
||||||
logpy.info(f"loading >>> {cfg['target']} <<< optimizer from config")
|
logging.info(f"loading >>> {cfg['target']} <<< optimizer from config")
|
||||||
return get_obj_from_str(cfg["target"])(
|
return get_obj_from_str(cfg["target"])(
|
||||||
params, lr=lr, **cfg.get("params", dict())
|
params, lr=lr, **cfg.get("params", dict())
|
||||||
)
|
)
|
||||||
@ -112,7 +114,7 @@ class AutoencodingEngine(AbstractAutoencoder):
|
|||||||
|
|
||||||
self.encoder: torch.nn.Module = instantiate_from_config(encoder_config)
|
self.encoder: torch.nn.Module = instantiate_from_config(encoder_config)
|
||||||
self.decoder: torch.nn.Module = instantiate_from_config(decoder_config)
|
self.decoder: torch.nn.Module = instantiate_from_config(decoder_config)
|
||||||
self.regularization: AbstractRegularizer = instantiate_from_config(
|
self.regularization = instantiate_from_config(
|
||||||
regularizer_config
|
regularizer_config
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
from functools import partial
|
||||||
from typing import Dict, Optional, List
|
from typing import Dict, Optional, List
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -6,4 +6,6 @@ lint.select = [
|
|||||||
"S307", # suspicious-eval-usage
|
"S307", # suspicious-eval-usage
|
||||||
"F401", # unused-import
|
"F401", # unused-import
|
||||||
"F841", # unused-local-variable
|
"F841", # unused-local-variable
|
||||||
|
# TODO: Enable F821 after all errors has been fixed. Remaining errors: 7.
|
||||||
|
# "F821", # undefined-name
|
||||||
]
|
]
|
Loading…
Reference in New Issue
Block a user