Fix ruff errors

Yoland Yan 2025-03-02 11:44:41 -08:00
parent 225a196dae
commit 2cd3c8a2fb
2 changed files with 16 additions and 20 deletions
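The hunks below replace print calls with logging, drop an f-string prefix from a string that interpolates nothing, touch the import block, and adjust whitespace, which are typical ruff findings. As a hedged sketch of how such a check might be run (the commit does not show the lint invocation or the repository's ruff configuration, so the paths and flags here are assumptions, not part of this change):

    import subprocess

    # Illustrative only: report ruff findings, then apply the safe autofixes.
    # Neither command is part of this commit; they just reproduce the kind of
    # check the fixes below respond to.
    subprocess.run(["ruff", "check", "."], check=False)
    subprocess.run(["ruff", "check", "--fix", "."], check=False)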

View File

@@ -1,10 +1,9 @@
import datetime
import io
import json
import math
import os
import logging
import matplotlib.pyplot as plt
import numpy as np
import safetensors
import torch
@@ -17,7 +16,6 @@ import folder_paths
import node_helpers
from comfy.cli_args import args
from comfy.comfy_types.node_typing import IO
from nodes import LoadImage
class TrainSampler(comfy.samplers.Sampler):
@@ -30,9 +28,9 @@ class TrainSampler(comfy.samplers.Sampler):
self.optimizer.zero_grad()
noise = model_wrap.inner_model.model_sampling.noise_scaling(sigmas, noise, latent_image, False)
latent = model_wrap.inner_model.model_sampling.noise_scaling(
-torch.zeros_like(sigmas),
-torch.zeros_like(noise, requires_grad=True),
-latent_image,
+torch.zeros_like(sigmas),
+torch.zeros_like(noise, requires_grad=True),
+latent_image,
False
)
@@ -42,9 +40,9 @@ class TrainSampler(comfy.samplers.Sampler):
loss = self.loss_fn(denoised, latent.clone())
except RuntimeError as e:
if "does not require grad and does not have a grad_fn" in str(e):
-print("WARNING: This is likely due to the model is loaded in inference mode.")
+logging.info("WARNING: This is likely due to the model is loaded in inference mode.")
loss.backward()
-print(f"Current Training Loss: {loss.item():.6f}")
+logging.info(f"Current Training Loss: {loss.item():.6f}")
if self.loss_callback:
self.loss_callback(loss.item())
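The two print-to-logging swaps above pair with the logging import shown in the first hunk. A minimal sketch of the resulting pattern, assuming a module-level logging setup that this diff does not itself show; the rule code T201 ("print found") is named here as the likely ruff check involved, not one cited by the commit:

    import logging

    # Emitting progress through logging instead of print keeps the message
    # filterable by level and handler, which is what rules like T201 push toward.
    logging.basicConfig(level=logging.INFO)  # illustrative setup, not in the diff
    loss_value = 0.123456                    # placeholder number for the example
    logging.info(f"Current Training Loss: {loss_value:.6f}")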
@@ -99,7 +97,7 @@ def load_and_process_images(image_files, input_dir, resize_method="None"):
torch.Tensor: Batch of processed images
"""
if not image_files:
-raise ValueError(f"No valid images found in input")
+raise ValueError("No valid images found in input")
output_images = []
w, h = None, None
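The ValueError change above only removes an f prefix from a string with no placeholders, the pattern ruff reports as F541 (f-string without any placeholders); the rule code is an inference, not something the commit states. A trivial sketch of why the two spellings are interchangeable:

    # An f-string with no placeholders is just a plain string with extra syntax,
    # so dropping the prefix changes nothing at runtime.
    assert f"No valid images found in input" == "No valid images found in input"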
@@ -406,9 +404,7 @@ class TrainLoraNode:
)
else:
if existing_lora != "[None]":
-print(
-f"Warning: No existing weights found for {lora_up_key} or {lora_down_key}"
-)
+logging.info(f"Warning: No existing weights found for {lora_up_key} or {lora_down_key}")
# Initialize new weights
lora_down = torch.nn.Parameter(
torch.zeros(

View File

@@ -9,31 +9,31 @@ def mock_folder_structure():
# Create a nested folder structure
folders = [
"folder1",
-os.path.join("folder1", "subfolder1"),
-os.path.join("folder1", "subfolder2"),
+"folder1/subfolder1",
+"folder1/subfolder2",
"folder2",
-os.path.join("folder2", "deep"),
-os.path.join("folder2", "deep", "nested"),
+"folder2/deep",
+"folder2/deep/nested",
"empty_folder"
]
# Create the folders
for folder in folders:
os.makedirs(os.path.join(temp_dir, folder))
# Add some files to test they're not included
with open(os.path.join(temp_dir, "root_file.txt"), "w") as f:
f.write("test")
with open(os.path.join(temp_dir, "folder1", "test.txt"), "w") as f:
f.write("test")
set_input_directory(temp_dir)
yield temp_dir
def test_gets_all_folders(mock_folder_structure):
folders = get_input_subfolders()
-expected = ["folder1", "folder1/subfolder1", "folder1/subfolder2",
+expected = ["folder1", "folder1/subfolder1", "folder1/subfolder2",
"folder2", "folder2/deep", "folder2/deep/nested", "empty_folder"]
assert sorted(folders) == sorted(expected)
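One observation on the test hunk above, offered as context rather than something the commit states: os.makedirs accepts forward-slash paths on every platform, so building the fixture folders from literal strings still works, but os.path.join only reproduces those exact strings on POSIX. A small sketch using the same names as the test, with the rest purely illustrative:

    import os

    # os.path.join("folder1", "subfolder1") gives "folder1/subfolder1" on POSIX
    # but "folder1\subfolder1" on Windows, so joined paths and hard-coded
    # forward-slash expectations are not interchangeable everywhere.
    joined = os.path.join("folder1", "subfolder1")
    literal = "folder1/subfolder1"
    assert joined == literal or os.sep == "\\"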