First commit: Created outline of red_ribbon and socialtoolkit classes. Architecture
2
.gitignore
vendored
@ -5,8 +5,6 @@ __pycache__/
|
||||
!/input/example.png
|
||||
/models/
|
||||
/temp/
|
||||
/custom_nodes/
|
||||
!custom_nodes/example_node.py.example
|
||||
extra_model_paths.yaml
|
||||
/.vs
|
||||
.vscode/
|
||||
|
28
custom_nodes/red_ribbon/__init__.py
Normal file
@ -0,0 +1,28 @@
|
||||
"""
|
||||
Red Ribbon - A collection of custom nodes for ComfyUI
|
||||
"""
|
||||
|
||||
import easy_nodes
|
||||
import os
|
||||
|
||||
# Version information
|
||||
__version__ = "0.1.0"
|
||||
|
||||
# NOTE This only needs to be called once.
|
||||
easy_nodes.initialize_easy_nodes(default_category="Red Ribbon")
|
||||
|
||||
# Import all modules - this must come after calling initialize_easy_nodes
|
||||
from . import main
|
||||
|
||||
# Get the combined node mappings for ComfyUI
|
||||
NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS = easy_nodes.get_node_mappings()
|
||||
|
||||
# Export so that ComfyUI can pick them up.
|
||||
__all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS']
|
||||
|
||||
# Optional: export the node list to a file so that e.g. ComfyUI-Manager can pick it up.
|
||||
easy_nodes.save_node_list(
|
||||
os.path.join(os.path.dirname(__file__), "red_ribbon_node_list.json")
|
||||
)
|
||||
|
||||
print(f"Red Ribbon v{__version__}: Successfully loaded {len(NODE_CLASS_MAPPINGS)} nodes")
|
46
custom_nodes/red_ribbon/choice_lists.py
Normal file
@ -0,0 +1,46 @@
|
||||
|
||||
|
||||
class GetChoices:
    """Fetches selectable choices from configured external sources."""

    def __init__(self, resources, configs):
        # Keep the injected dependencies; no other state is held.
        self.configs = configs
        self.resources = resources

    def get_choices(self, source: str):
        """
        Get choices of things from certain websites

        NOTE(review): not implemented yet — currently a no-op returning None.
        """
|
||||
|
||||
|
||||
|
||||
|
||||
# Create a list of available models for the API
AVAILABLE_MODELS: list[str] = []
ANTHROPIC_MODELS = [
    "claude-3-5-sonnet-latest",
    "claude-3-5-haiku-latest",
    "claude-3-opus-latest",
]
OPEN_AI_MODELS = [
    "gpt-4o",
    "chatgpt-4o-latest",
    "gpt-4o-mini",
    "o1",
    "o1-mini",
    "o1-preview",
    "gpt-4o-realtime-preview",
    "gpt-4o-mini-realtime-preview",
    "gpt-4o-audio-preview",
]
AVAILABLE_MODELS.extend(ANTHROPIC_MODELS)
AVAILABLE_MODELS.extend(OPEN_AI_MODELS)


OPEN_AI_EMBEDDING_MODELS = [
    "text-embedding-3-small",
    "text-embedding-3-large",
    "text-embedding-ada-002",
]
TEXT_EMBEDDING_MODELS: list[str] = []
# BUG FIX: was .append(OPEN_AI_EMBEDDING_MODELS), which nested the whole
# list as a single element; extend() flattens it, matching how
# AVAILABLE_MODELS is built above.
TEXT_EMBEDDING_MODELS.extend(OPEN_AI_EMBEDDING_MODELS)
|
34
custom_nodes/red_ribbon/config_example.py
Normal file
@ -0,0 +1,34 @@
|
||||
#!/usr/bin/env python
|
||||
# Example of using the Configs class
|
||||
|
||||
from configs import Configs
|
||||
|
||||
def main():
    """Walk through read-only access to the cached Configs singleton."""
    config = Configs.get()

    # Scalar settings are plain read-only attributes.
    print("API URL:", config.API_URL)
    print("Debug Mode:", config.DEBUG_MODE)
    print("Max Batch Size:", config.MAX_BATCH_SIZE)

    # Nested dict settings iterate like any mapping.
    print("\nModel Paths:")
    for model_type, path in config.MODEL_PATHS.items():
        print(f"  {model_type}: {path}")

    print("\nCustom Settings:")
    for key, value in config.CUSTOM_SETTINGS.items():
        print(f"  {key}: {value}")

    # The model is frozen, so assignment must raise.
    try:
        config.API_URL = "http://new-url.com"
    except Exception as e:
        print(f"\nAttempted to modify API_URL and got: {type(e).__name__}: {e}")

    # Prove nothing changed after the failed write.
    print("\nAPI URL is still:", config.API_URL)


if __name__ == "__main__":
    main()
|
105
custom_nodes/red_ribbon/configs.py
Normal file
@ -0,0 +1,105 @@
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
import os
|
||||
import yaml
|
||||
from pydantic import BaseModel, Field
|
||||
from functools import lru_cache
|
||||
|
||||
|
||||
|
||||
class Paths(BaseModel):
    # Filesystem layout anchored at this file's own location:
    # <ComfyUI>/custom_nodes/red_ribbon/configs.py
    # NOTE(review): these attributes carry no type annotations, so pydantic
    # does not treat them as model fields; some pydantic versions reject
    # non-annotated, non-ClassVar attributes on a BaseModel — confirm the
    # installed pydantic version accepts this.
    THIS_FILE = Path(__file__).resolve()
    RED_RIBBON_DIR = THIS_FILE.parent           # .../custom_nodes/red_ribbon
    CUSTOM_NODES_DIR = RED_RIBBON_DIR.parent    # .../custom_nodes
    COMFYUI_DIR = CUSTOM_NODES_DIR.parent       # ComfyUI root
    LLM_OUTPUTS_DIR = COMFYUI_DIR / "output" / "red_ribbon_outputs"
    LLM_MODELS_DIR = COMFYUI_DIR / "models" / "llm_models"

    class Config:
        frozen = True  # Make the model immutable (read-only)
|
||||
|
||||
|
||||
class SocialToolkitConfigs(BaseModel):
    """Configuration for High Level Architecture workflow"""
    # Required fields — no defaults, so every instantiation must supply them.
    # NOTE(review): Configs._socialtoolkit uses default_factory=SocialToolkitConfigs,
    # which would call this with no arguments — verify that works.
    approved_document_sources: list[str]
    llm_api_config: dict[str, Any]
    # Workflow tuning knobs.
    document_retrieval_threshold: int = 10
    relevance_threshold: float = 0.7
    output_format: str = "json"

    # Optional per-stage override configs; None means "use stage defaults".
    codebook: Optional[dict[str, Any]] = None
    document_retrieval: Optional[dict[str, Any]] = None
    llm_service: Optional[dict[str, Any]] = None
    top10_retrieval: Optional[dict[str, Any]] = None
    relevance_assessment: Optional[dict[str, Any]] = None
    prompt_decision_tree: Optional[dict[str, Any]] = None
|
||||
|
||||
|
||||
class ConfigsBase(BaseModel):
    """Base model for configuration with read-only fields."""

    class Config:
        # Subclasses inherit frozen-ness: attribute assignment raises
        # a pydantic error after construction.
        frozen = True  # Make the model immutable (read-only)
|
||||
|
||||
|
||||
def _load_yaml_if_exists(path: str) -> dict:
    """Return the parsed YAML mapping at *path*, or {} if the file is absent or empty."""
    if os.path.exists(path):
        with open(path, 'r') as f:
            return yaml.safe_load(f) or {}
    return {}


@lru_cache()
def get_config() -> 'Configs':
    """
    Load configuration from YAML files and cache the result.

    Reads configs.yaml first, then private_configs.yaml; keys in the
    private file override the base file.  Returns a read-only Configs
    object (cached — every call returns the same instance).
    """
    base_dir = os.path.dirname(os.path.abspath(__file__))

    # Load main configs, then private configs (overrides main configs).
    config_data = _load_yaml_if_exists(os.path.join(base_dir, "configs.yaml"))
    private_config_data = _load_yaml_if_exists(os.path.join(base_dir, "private_configs.yaml"))

    # Merge configs, with private taking precedence
    merged_config = {**config_data, **private_config_data}

    return Configs(**merged_config)
|
||||
|
||||
|
||||
class Configs(ConfigsBase):
    """
    Configuration constants loaded from YAML files.
    All fields are read-only.

    Loads from:
    - configs.yaml (base configuration)
    - private_configs.yaml (overrides base configuration)
    """
    # Add your configuration fields here with defaults
    # Example:
    API_URL: str = Field("http://localhost:8000", description="API URL")
    DEBUG_MODE: bool = Field(default=False, description="Enable debug mode")
    MAX_BATCH_SIZE: int = Field(default=4, description="Maximum batch size")
    MODEL_PATHS: Dict[str, str] = Field(default_factory=dict, description="Paths to models")
    CUSTOM_SETTINGS: Dict[str, Any] = Field(default_factory=dict, description="Custom configuration settings")

    # NOTE(review): pydantic treats leading-underscore names as private
    # attributes, not model fields; declaring them with Field(...) may be
    # ignored or rejected depending on the pydantic version (PrivateAttr is
    # the usual mechanism) — confirm this loads.  Also SocialToolkitConfigs
    # has required fields with no defaults, so default_factory=SocialToolkitConfigs
    # would presumably raise a validation error when invoked — verify.
    _paths: Paths = Field(default_factory=Paths)
    _socialtoolkit: SocialToolkitConfigs = Field(default_factory=SocialToolkitConfigs)

    # Access the singleton instance through this class method
    @classmethod
    def get(cls) -> 'Configs':
        """Get the singleton instance of Configs."""
        return get_config()

    @property
    def paths(self) -> Paths:
        # Read-only accessor for the filesystem-layout constants.
        return self._paths

    @property
    def socialtoolkit(self) -> SocialToolkitConfigs:
        # Read-only accessor for the SocialToolkit workflow configuration.
        return self._socialtoolkit
|
22
custom_nodes/red_ribbon/configs.yaml
Normal file
@ -0,0 +1,22 @@
|
||||
# Main Configuration File
|
||||
# These are the default settings that can be overridden by private_configs.yaml
|
||||
|
||||
# API Settings
|
||||
API_URL: "http://localhost:8000"
|
||||
DEBUG_MODE: false
|
||||
MAX_BATCH_SIZE: 4
|
||||
|
||||
# Model Paths
|
||||
MODEL_PATHS:
|
||||
stable_diffusion: "models/stable-diffusion"
|
||||
controlnet: "models/controlnet"
|
||||
vae: "models/vae"
|
||||
lora: "models/lora"
|
||||
|
||||
# Custom Settings
|
||||
CUSTOM_SETTINGS:
|
||||
cache_dir: "cache"
|
||||
temp_dir: "temp"
|
||||
max_image_size: 2048
|
||||
default_sampler: "euler_a"
|
||||
default_scheduler: "normal"
|
129
custom_nodes/red_ribbon/main.py
Normal file
@ -0,0 +1,129 @@
|
||||
"""
|
||||
Red Ribbon - Main module for importing and registering all nodes
|
||||
"""
|
||||
from typing import Type
|
||||
|
||||
|
||||
from easy_nodes import (
|
||||
NumberInput,
|
||||
ComfyNode,
|
||||
StringInput,
|
||||
Choice,
|
||||
)
|
||||
|
||||
|
||||
# Import components from subdirectories
|
||||
from .socialtoolkit.socialtoolkit import SocialToolkitAPI
|
||||
from .red_ribbon_core.red_ribbon import RedRibbonAPI
|
||||
from .plug_in_play_transformer.plug_in_play_transformer import TransformerAPI
|
||||
from .utils.utils import UtilsAPI
|
||||
from .configs import Configs
|
||||
from .node_types import register_pydantic_models
|
||||
|
||||
|
||||
# Modules whose Pydantic models should be registered as EasyNodes types.
# NOTE(review): importlib.import_module() receives these bare names; inside
# a package they would normally need to be fully qualified (e.g.
# "custom_nodes.red_ribbon.configs") — confirm they resolve at load time.
modules_to_register = [
    "red_ribbon",
    "socialtoolkit",
    "utils",
    "plug_in_play_transformer",
    "configs",
]
register_pydantic_models(modules_to_register)
|
||||
|
||||
|
||||
class RedRibbonPackage:
    """Main interface for the Red Ribbon package"""

    def __init__(self, resources: dict[str, object] = None, configs: "Configs" = None):
        """Initialize the Red Ribbon package components.

        Args:
            resources: mapping of component key ("social", "rr", "trans",
                "utils") to the corresponding API instance.  Missing keys
                leave that component as None.
            configs: shared Configs object handed to all components.
        """
        self.configs = configs
        # BUG FIX: default to an empty dict — the original kept None and then
        # called self.resources.get(...), so RedRibbonPackage() with no
        # arguments (as main() does) raised AttributeError.
        self.resources = resources if resources is not None else {}

        # NOTE: these attributes hold API *instances*, not classes, so the
        # original Type[...] annotations were misleading; string forward
        # references also avoid evaluating project names at runtime.
        self.social: "SocialToolkitAPI" = self.resources.get("social")
        self.rr: "RedRibbonAPI" = self.resources.get("rr")
        self.trans: "TransformerAPI" = self.resources.get("trans")
        self.utils: "UtilsAPI" = self.resources.get("utils")

    def version(self):
        """Get the version of the Red Ribbon package"""
        from . import __version__
        return __version__
|
||||
|
||||
|
||||
# Per-component resource dicts; intentionally empty for now — each API is
# expected to fall back to its own defaults when a resource key is absent.
rr_resources = {
}
social_resources = {
}
trans_resources = {
}
utils_resources = {
}

# NOTE(review): instantiates Configs() directly rather than the cached
# Configs.get() singleton used elsewhere — confirm this is intentional.
configs = Configs()
resources = {
    "social": SocialToolkitAPI(social_resources, configs),
    "rr": RedRibbonAPI(rr_resources, configs),
    "trans": TransformerAPI(trans_resources, configs),
    "utils": UtilsAPI(utils_resources, configs)
}

# Initialize the Red Ribbon package
package = RedRibbonPackage(resources, configs)
|
||||
|
||||
|
||||
@ComfyNode("Socialtoolkit",
|
||||
color="#d30e0e",
|
||||
bg_color="#ff0000",
|
||||
display_name="Rank and Sort Similarity Search Results")
|
||||
def rank_and_sort_similar_search_results(
|
||||
search_results: list,
|
||||
search_query: str,
|
||||
search_type: str,
|
||||
rank_by: str,
|
||||
sort_by: str
|
||||
) -> list:
|
||||
"""
|
||||
Rank and sort similarity search results.
|
||||
"""
|
||||
return package.social.rank_and_sort_similar_search_results(
|
||||
search_results,
|
||||
search_query,
|
||||
search_type,
|
||||
rank_by,
|
||||
sort_by
|
||||
)
|
||||
|
||||
@ComfyNode("Socialtoolkit",
|
||||
color="#d30e0e",
|
||||
bg_color="#ff0000",
|
||||
display_name="Retrieve Documents from Websites")
|
||||
def document_retrieval_from_websites(
|
||||
domain_urls: list[str]
|
||||
) -> tuple['Document', 'Metadata', 'Vectors']:
|
||||
"""
|
||||
Document retrieval from websites.
|
||||
"""
|
||||
resources: dict[str, object],
|
||||
configs: Configs
|
||||
|
||||
socialtoolkit = SocialToolkitAPI(resources, configs)
|
||||
return socialtoolkit.document_retrieval_from_websites(
|
||||
domain_urls
|
||||
)
|
||||
|
||||
|
||||
# Main function that can be called when using this as a script
def main():
    """Print a short summary of the package and its components."""
    print("Red Ribbon package loaded successfully")
    package = RedRibbonPackage()
    print(f"Version: {package.version()}")
    print("Available components:")
    for component in (
        "- SocialToolkit",
        "- RedRibbon Core",
        "- Plug-in-Play Transformer",
        "- Utils",
    ):
        print(component)


if __name__ == "__main__":
    main()
|
||||
|
||||
|
74
custom_nodes/red_ribbon/node_types.py
Normal file
@ -0,0 +1,74 @@
|
||||
import importlib
|
||||
import inspect
|
||||
from typing import Type
|
||||
|
||||
from easy_nodes import register_type
|
||||
from easy_nodes.easy_nodes import AnythingVerifier
|
||||
try:
|
||||
from pydantic import BaseModel
|
||||
except ImportError:
|
||||
print("Pydantic not found. Please install it with 'pip install pydantic'")
|
||||
BaseModel = object # Fallback if pydantic isn't installed
|
||||
|
||||
|
||||
def registration_callback(register_these_classes: list[Type[BaseModel]]) -> None:
    """Register each class with EasyNodes under its upper-cased qualname."""
    for model_cls in register_these_classes:
        type_name = model_cls.__qualname__.upper()
        register_type(model_cls, type_name, verifier=AnythingVerifier())
|
||||
|
||||
|
||||
def register_pydantic_models(
    module_names: list[str],
) -> list:
    """
    Loads Pydantic classes from specified modules and registers them.

    Args:
        module_names: list of module names to search for Pydantic models

    Returns:
        The list of Pydantic model classes that were found.
        Side-effect: registers the models with EasyNodes via the
        module-level registration_callback.

    FIXES vs. original: the return annotation said ``-> None`` although the
    function returns ``models``, and the docstring documented a
    ``registration_callback`` parameter that does not exist.
    """
    models = []
    for module_name in module_names:
        try:
            # Import the module
            module = importlib.import_module(module_name)

            # Find all Pydantic classes in the module
            for _, obj in inspect.getmembers(module):
                # Keep strict subclasses of BaseModel only (not BaseModel itself).
                if (inspect.isclass(obj) and
                        issubclass(obj, BaseModel) and
                        obj is not BaseModel):
                    models.append(obj)
        except ImportError as e:
            print(f"Error importing module {module_name}: {e}")
        except Exception as e:
            print(f"Error processing module {module_name}: {e}")

    # Register the models using the module-level callback; registration
    # failures are reported but do not abort the scan.
    try:
        registration_callback(models)
    except Exception as e:
        print(f"{type(e)} registering models: {e}")

    return models
|
||||
|
||||
# Example usage:
|
||||
# if __name__ == "__main__":
|
||||
# modules_to_scan = ["your_module.models", "another_module.types"]
|
||||
# models = register_pydantic_models(modules_to_scan)
|
||||
# print(f"Found {len(models)} Pydantic models")
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
@ -0,0 +1,5 @@
|
||||
"""
|
||||
Plug-in-Play Transformer Module for Red Ribbon
|
||||
"""
|
||||
|
||||
|
@ -0,0 +1,22 @@
|
||||
"""
|
||||
Plug-in-Play Transformer - Main entrance file for transformer functionality
|
||||
"""
|
||||
|
||||
from . import PiPTransformerNode
|
||||
|
||||
class TransformerAPI:
    """API for accessing Transformer functionality from other modules"""

    def __init__(self, resources, configs):
        # Hold the injected dependencies; no other state yet.
        self.resources = resources
        self.configs = configs
|
||||
|
||||
# Main function that can be called when using this as a script
def main():
    """Print a short status summary for the transformer module."""
    for line in (
        "Plug-in-Play Transformer module loaded successfully",
        "Available tools:",
        "- PiPTransformerNode: Node for ComfyUI integration",
        "- TransformerAPI: API for programmatic access",
    ):
        print(line)


if __name__ == "__main__":
    main()
|
14
custom_nodes/red_ribbon/private_configs.yaml
Normal file
@ -0,0 +1,14 @@
|
||||
# Private Configuration File
|
||||
# These settings override the ones in configs.yaml
|
||||
# Add this file to .gitignore to keep sensitive information private
|
||||
|
||||
# API Settings
|
||||
API_URL: "http://localhost:8001" # Override the default port
|
||||
|
||||
# Custom private settings
|
||||
DEBUG_MODE: true # Enable debug mode in development environment
|
||||
|
||||
# Custom Settings with overrides
|
||||
CUSTOM_SETTINGS:
|
||||
api_key: "your_secret_api_key_here" # Add private API keys
|
||||
database_url: "postgresql://user:password@localhost:5432/db"
|
40
custom_nodes/red_ribbon/red_ribbon.py
Normal file
@ -0,0 +1,40 @@
|
||||
"""
|
||||
Red Ribbon - Main entrance file for the entire Red Ribbon package
|
||||
"""
|
||||
|
||||
# Import components from subdirectories
|
||||
from .socialtoolkit.socialtoolkit import SocialToolkitAPI
|
||||
from .red_ribbon_core.red_ribbon import RedRibbonAPI
|
||||
from .plug_in_play_transformer.plug_in_play_transformer import TransformerAPI
|
||||
from .utils.utils import UtilsAPI
|
||||
|
||||
|
||||
from .main import package
|
||||
|
||||
|
||||
|
||||
# Main function that can be called when using this as a script
|
||||
def main():
|
||||
print("Red Ribbon package loaded successfully")
|
||||
print(f"Version: {package.version()}")
|
||||
print("Available components:")
|
||||
print("- SocialToolkit")
|
||||
print("- RedRibbon Core")
|
||||
print("- Plug-in-Play Transformer")
|
||||
print("- Utils")
|
||||
while True:
|
||||
choice_was = input("Enter your choice: ")
|
||||
match choice_was:
|
||||
case "SocialToolkit":
|
||||
print("SocialToolkitAPI")
|
||||
case "RedRibbon Core":
|
||||
print("RedRibbonAPI")
|
||||
case "Plug-in-Play Transformer":
|
||||
print("TransformerAPI")
|
||||
case "Utils":
|
||||
print("UtilsAPI")
|
||||
case _:
|
||||
print("Invalid choice. Try again.")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
39
custom_nodes/red_ribbon/red_ribbon_core/__init__.py
Normal file
@ -0,0 +1,39 @@
|
||||
"""
|
||||
Red Ribbon Core Module
|
||||
"""
|
||||
|
||||
class RedRibbonNode:
    """Main node for Red Ribbon functionality"""

    # ComfyUI node contract.
    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "process"
    CATEGORY = "Red Ribbon/Effects"

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the inputs ComfyUI should render for this node."""
        effect_options = ["basic", "advanced", "extreme"]
        return {
            "required": {
                "image": ("IMAGE",),
                "effect": (effect_options, {"default": "basic"}),
                "intensity": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 2.0, "step": 0.01}),
            },
        }

    def process(self, image, effect, intensity):
        """Apply the selected Red Ribbon effect to *image*.

        Placeholder: currently passes the image through unchanged.
        """
        return (image,)
|
||||
|
||||
# Dictionary of nodes to be imported by main.py
NODE_CLASS_MAPPINGS = {"RedRibbonNode": RedRibbonNode}

# Add display names for the nodes
NODE_DISPLAY_NAME_MAPPINGS = {"RedRibbonNode": "Red Ribbon Effect"}


def red_ribbon():
    """Return the (class-mapping, display-name-mapping) pair for main.py."""
    return NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS
|
36
custom_nodes/red_ribbon/red_ribbon_core/red_ribbon.py
Normal file
@ -0,0 +1,36 @@
|
||||
"""
|
||||
Red Ribbon - Main entrance file for core Red Ribbon functionality
|
||||
"""
|
||||
|
||||
from . import RedRibbonNode
|
||||
|
||||
class RedRibbonAPI:
    """API for accessing Red Ribbon functionality from other modules"""

    def __init__(self, resources, configs):
        # Injected dependencies only; no other state.
        self.resources = resources
        self.configs = configs

    def create_text_embedding(self, text):
        """Create an embedding for the given text

        Args:
            text (str): The text to embed

        Returns:
            list: The embedding vector
        """
        # Placeholder "embedding": one code point per character.
        # In a real implementation, this would use an actual embedding model.
        return [ord(ch) for ch in text]
|
||||
|
||||
|
||||
# Main function that can be called when using this as a script
def main():
    """Print a short status summary for the core module."""
    for line in (
        "Red Ribbon core module loaded successfully",
        "Available tools:",
        "- RedRibbonNode: Node for ComfyUI integration",
        "- RedRibbonAPI: API for programmatic access",
    ):
        print(line)


if __name__ == "__main__":
    main()
|
37
custom_nodes/red_ribbon/socialtoolkit/__init__.py
Normal file
@ -0,0 +1,37 @@
|
||||
"""
|
||||
Social Toolkit Module for Red Ribbon
|
||||
"""
|
||||
|
||||
class SocialToolkitNode:
    """Node for social media integration tools"""

    # ComfyUI node contract.
    RETURN_TYPES = ("STRING",)
    FUNCTION = "process"
    CATEGORY = "Red Ribbon/Social"

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the inputs ComfyUI should render for this node."""
        platforms = ["twitter", "instagram", "facebook"]
        return {
            "required": {
                "text": ("STRING", {"multiline": True}),
                "platform": (platforms, {"default": "twitter"}),
            },
        }

    def process(self, text, platform):
        """Prefix *text* with the upper-cased platform tag."""
        return (f"[{platform.upper()}]: {text}",)
|
||||
|
||||
# Dictionary of nodes to be imported by main.py
NODE_CLASS_MAPPINGS = {"SocialToolkitNode": SocialToolkitNode}

# Add display names for the nodes
NODE_DISPLAY_NAME_MAPPINGS = {"SocialToolkitNode": "Social Media Toolkit"}


def socialtoolkit():
    """Return the (class-mapping, display-name-mapping) pair for main.py."""
    return NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS
|
@ -0,0 +1,156 @@
|
||||
from pydantic import BaseModel
|
||||
from typing import Dict, List, Any, Optional, Tuple
|
||||
from enum import Enum
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class WebpageType(str, Enum):
    """How a webpage's content is rendered: plain HTML vs. JS-driven."""

    STATIC = "static"
    DYNAMIC = "dynamic"
|
||||
|
||||
class DocumentRetrievalConfigs(BaseModel):
    """Configuration for Document Retrieval from Websites workflow"""
    # HTTP fetch behavior.
    timeout_seconds: int = 30
    max_retries: int = 3
    # Desktop-Chrome UA string sent with every request.
    user_agent: str = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36"
    # Seconds to wait for JS rendering when parsing dynamic pages.
    dynamic_rendering_wait_time: int = 5
    selenium_enabled: bool = False
    # Extra request headers.  Pydantic copies field defaults per instance,
    # so the usual shared-mutable-default pitfall does not apply here.
    headers: Dict[str, str] = {}
    batch_size: int = 10
    # Link-following (crawl) controls.
    follow_links: bool = False
    max_depth: int = 1
|
||||
|
||||
|
||||
class DocumentRetrievalFromWebsites:
    """
    Document Retrieval from Websites for data extraction system
    based on mermaid chart in README.md

    Pipeline: domain URL -> page URLs -> parse (static/dynamic) ->
    extract raw strings -> build documents/vectors/metadata -> store.
    """

    def __init__(self, resources: Dict[str, Any], configs: DocumentRetrievalConfigs):
        """
        Initialize with injected dependencies and configuration

        Args:
            resources: Dictionary of resources including services
            configs: Configuration for Document Retrieval
        """
        self.resources = resources
        self.configs = configs

        # Extract needed services from resources.  Each may be None if the
        # caller did not provide it; execute() assumes all are present.
        self.static_webpage_parser = resources.get("static_webpage_parser")
        self.dynamic_webpage_parser = resources.get("dynamic_webpage_parser")
        self.data_extractor = resources.get("data_extractor")
        self.vector_generator = resources.get("vector_generator")
        self.metadata_generator = resources.get("metadata_generator")
        self.document_storage = resources.get("document_storage_service")
        self.url_path_generator = resources.get("url_path_generator")

        logger.info("DocumentRetrievalFromWebsites initialized with services")

    def execute(self, domain_urls: List[str]) -> Dict[str, Any]:
        """
        Execute the document retrieval flow based on the mermaid chart

        Args:
            domain_urls: List of domain URLs to retrieve documents from

        Returns:
            Dictionary containing retrieved documents, metadata, and vectors
        """
        logger.info(f"Starting document retrieval from {len(domain_urls)} domains")

        all_documents = []
        all_metadata = []
        all_vectors = []

        for domain_url in domain_urls:
            # Step 1: Generate URLs from domain URL
            urls = self._generate_urls(domain_url)

            for url in urls:
                # Step 2: Determine webpage type and parse accordingly
                webpage_type = self._determine_webpage_type(url)

                if webpage_type == WebpageType.STATIC:
                    raw_data = self.static_webpage_parser.parse(url)
                else:
                    raw_data = self.dynamic_webpage_parser.parse(url)

                # Step 3: Extract structured data from raw data
                raw_strings = self.data_extractor.extract(raw_data)

                # Step 4: Generate documents, vectors, and metadata
                documents = self._create_documents(raw_strings, url)
                document_vectors = self.vector_generator.generate(documents)
                document_metadata = self.metadata_generator.generate(documents, url)

                all_documents.extend(documents)
                all_vectors.extend(document_vectors)
                all_metadata.extend(document_metadata)

        # Step 5: Store documents, vectors, and metadata
        self.document_storage.store(all_documents, all_metadata, all_vectors)

        logger.info(f"Retrieved and stored {len(all_documents)} documents")
        return {
            "documents": all_documents,
            "metadata": all_metadata,
            "vectors": all_vectors
        }

    def retrieve_documents(self, domain_urls: List[str]) -> Tuple[List[Any], List[Any], List[Any]]:
        """
        Public method to retrieve documents from websites

        Args:
            domain_urls: List of domain URLs to retrieve documents from

        Returns:
            Tuple of (documents, metadata, vectors)
        """
        # BUG FIX: the original called self.control_flow(domain_urls), a
        # method that does not exist on this class, so this method always
        # raised AttributeError.  The pipeline entry point is execute().
        result = self.execute(domain_urls)
        return (
            result["documents"],
            result["metadata"],
            result["vectors"]
        )

    def _generate_urls(self, domain_url: str) -> List[str]:
        """Generate URLs from domain URL using URL path generator"""
        return self.url_path_generator.generate(domain_url)

    def _determine_webpage_type(self, url: str) -> WebpageType:
        """
        Determine whether a webpage is static or dynamic

        This is a simple implementation that could be enhanced with
        more sophisticated detection mechanisms
        """
        # Check URL patterns that typically indicate dynamic content
        dynamic_indicators = [
            "#!", "?", "api", "ajax", "load", "spa", "react",
            "angular", "vue", "dynamic", "js-rendered"
        ]

        for indicator in dynamic_indicators:
            if indicator in url.lower():
                return WebpageType.DYNAMIC

        return WebpageType.STATIC

    def _create_documents(self, raw_strings: List[str], url: str) -> List[Any]:
        """Create documents from raw strings"""
        # Placeholder implementation: wrap each raw string in a dict.
        # NOTE(review): assumes a "timestamp_service" resource exposing
        # .now() is always present — otherwise this raises AttributeError;
        # confirm callers always inject it.
        documents = []
        for i, content in enumerate(raw_strings):
            documents.append({
                "id": f"{url}_{i}",
                "content": content,
                "url": url,
                "timestamp": self.resources.get("timestamp_service").now()
            })
        return documents
|
@ -0,0 +1,492 @@
|
||||
from pydantic import BaseModel
|
||||
from typing import Dict, List, Any, Optional, Tuple
|
||||
from enum import Enum
|
||||
from uuid import UUID, uuid4
|
||||
from datetime import datetime
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class DocumentStatus(str, Enum):
    """Processing lifecycle state of a stored document."""

    NEW = "new"
    PROCESSING = "processing"
    COMPLETE = "complete"
    ERROR = "error"
|
||||
|
||||
class VersionStatus(str, Enum):
    """Publication state of a document version."""

    DRAFT = "draft"
    ACTIVE = "active"
    SUPERSEDED = "superseded"
|
||||
|
||||
class SourceType(str, Enum):
    """Provenance tier of a document source."""

    PRIMARY = "primary"
    SECONDARY = "secondary"
    TERTIARY = "tertiary"
|
||||
|
||||
class DocumentStorageConfigs(BaseModel):
    """Configuration for Document Storage"""
    # Required — no default; every instantiation must supply it.
    database_connection_string: str
    # Read-through cache controls.
    cache_enabled: bool = True
    cache_ttl_seconds: int = 3600
    # Rows per write batch.
    batch_size: int = 100
    vector_dim: int = 1536  # Common dimension for embeddings like OpenAI's
    storage_type: str = "sql"  # Alternatives: "nosql", "in_memory", etc.
|
||||
|
||||
class DocumentStorage:
|
||||
"""
|
||||
Document Storage system based on mermaid ER diagram in README.md
|
||||
Manages the storage and retrieval of documents, versions, metadata, and vectors
|
||||
"""
|
||||
|
||||
def __init__(self, resources: Dict[str, Any], configs: DocumentStorageConfigs):
|
||||
"""
|
||||
Initialize with injected dependencies and configuration
|
||||
|
||||
Args:
|
||||
resources: Dictionary of resources including storage services
|
||||
configs: Configuration for Document Storage
|
||||
"""
|
||||
self.resources = resources
|
||||
self.configs = configs
|
||||
|
||||
# Extract needed services from resources
|
||||
self.db_service = resources.get("database_service")
|
||||
self.cache_service = resources.get("cache_service")
|
||||
self.vector_store = resources.get("vector_store_service")
|
||||
self.id_generator = resources.get("id_generator_service", self._generate_uuid)
|
||||
|
||||
logger.info("DocumentStorage initialized with services")
|
||||
|
||||
def execute(self, action: str, **kwargs) -> Dict[str, Any]:
|
||||
"""
|
||||
Execute document storage operations based on the action
|
||||
|
||||
Args:
|
||||
action: Operation to perform (store, retrieve, update, delete)
|
||||
**kwargs: Operation-specific parameters
|
||||
|
||||
Returns:
|
||||
Dictionary containing operation results
|
||||
"""
|
||||
logger.info(f"Starting document storage operation: {action}")
|
||||
|
||||
if action == "store":
|
||||
return self._store_documents(
|
||||
kwargs.get("documents", []),
|
||||
kwargs.get("metadata", []),
|
||||
kwargs.get("vectors", [])
|
||||
)
|
||||
elif action == "retrieve":
|
||||
return self._retrieve_documents(
|
||||
document_ids=kwargs.get("document_ids", []),
|
||||
filters=kwargs.get("filters", {})
|
||||
)
|
||||
elif action == "update":
|
||||
return self._update_documents(
|
||||
kwargs.get("documents", [])
|
||||
)
|
||||
elif action == "delete":
|
||||
return self._delete_documents(
|
||||
kwargs.get("document_ids", [])
|
||||
)
|
||||
elif action == "get_vectors":
|
||||
return self._get_vectors(
|
||||
kwargs.get("document_ids", [])
|
||||
)
|
||||
else:
|
||||
logger.error(f"Unknown action: {action}")
|
||||
raise ValueError(f"Unknown action: {action}")
|
||||
|
||||
def store(self, documents: List[Any], metadata: List[Any], vectors: List[Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Store documents, metadata, and vectors
|
||||
|
||||
Args:
|
||||
documents: Documents to store
|
||||
metadata: Metadata for the documents
|
||||
vectors: Vectors for the documents
|
||||
|
||||
Returns:
|
||||
Dictionary with storage status
|
||||
"""
|
||||
return self.control_flow("store", documents=documents, metadata=metadata, vectors=vectors)
|
||||
|
||||
def get_documents_and_vectors(self, document_ids: List[str] = None,
|
||||
filters: Dict[str, Any] = None) -> Tuple[List[Any], List[Any]]:
|
||||
"""
|
||||
Retrieve documents and their vectors
|
||||
|
||||
Args:
|
||||
document_ids: Optional list of document IDs to retrieve
|
||||
filters: Optional filters to apply
|
||||
|
||||
Returns:
|
||||
Tuple of (documents, vectors)
|
||||
"""
|
||||
result = self.control_flow(
|
||||
"retrieve", document_ids=document_ids, filters=filters
|
||||
)
|
||||
documents = result.get("documents", [])
|
||||
|
||||
vectors_result = self.control_flow(
|
||||
"get_vectors", document_ids=[doc.get("id") for doc in documents]
|
||||
)
|
||||
vectors = vectors_result.get("vectors", [])
|
||||
|
||||
return documents, vectors
|
||||
|
||||
def _store_documents(self, documents: List[Any], metadata: List[Any], vectors: List[Any]) -> Dict[str, Any]:
    """Store documents, metadata, and vectors in the database.

    Runs the full ingest pipeline in a fixed order: sources -> documents ->
    versions -> metadata -> content -> version/content links -> vectors.
    Later steps depend on the IDs generated by earlier ones.

    Args:
        documents: Documents to persist (read for "url"/"content" fields).
        metadata: Per-document metadata entries, paired positionally.
        vectors: Per-document embedding vectors, paired positionally.

    Returns:
        On success: {"success": True, "document_ids": ..., "version_ids": ...,
        "content_ids": ...}; on failure: {"success": False, "error": <msg>}.
    """
    try:
        # 1. Store source information if needed
        source_ids = self._store_sources(documents)

        # 2. Store documents
        document_ids = self._store_document_entries(documents, source_ids)

        # 3. Create versions for documents
        version_ids = self._create_versions(document_ids)

        # 4. Store metadata
        self._store_metadata(metadata, document_ids)

        # 5. Store content
        content_ids = self._store_content(documents, version_ids)

        # 6. Create version-content associations
        self._create_version_content_links(version_ids, content_ids)

        # 7. Store vectors
        self._store_vectors(vectors, content_ids)

        return {
            "success": True,
            "document_ids": document_ids,
            "version_ids": version_ids,
            "content_ids": content_ids
        }
    except Exception as e:
        # Any pipeline failure is collapsed into an error result rather than
        # raised; NOTE(review): a partial failure leaves earlier inserts in
        # place (no transaction/rollback) — confirm whether that is intended.
        logger.error(f"Error storing documents: {e}")
        return {
            "success": False,
            "error": str(e)
        }
|
||||
|
||||
def _retrieve_documents(self, document_ids: List[str] = None,
|
||||
filters: Dict[str, Any] = None) -> Dict[str, Any]:
|
||||
"""Retrieve documents from the database"""
|
||||
try:
|
||||
documents = []
|
||||
|
||||
# Use document IDs if provided, otherwise use filters
|
||||
if document_ids:
|
||||
query = f"SELECT * FROM Documents WHERE document_id IN ({','.join(['?']*len(document_ids))})"
|
||||
documents = self.db_service.execute(query, document_ids)
|
||||
elif filters:
|
||||
# Build WHERE clause based on filters
|
||||
where_clauses = []
|
||||
params = []
|
||||
|
||||
for key, value in filters.items():
|
||||
where_clauses.append(f"{key} = ?")
|
||||
params.append(value)
|
||||
|
||||
query = f"SELECT * FROM Documents WHERE {' AND '.join(where_clauses)}"
|
||||
documents = self.db_service.execute(query, params)
|
||||
else:
|
||||
# Retrieve all documents (with limit)
|
||||
query = f"SELECT * FROM Documents LIMIT {self.configs.batch_size}"
|
||||
documents = self.db_service.execute(query)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"documents": documents
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving documents: {e}")
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
def _update_documents(self, documents: List[Any]) -> Dict[str, Any]:
    """Update existing documents (stub).

    Args:
        documents: Documents to update.

    Returns:
        {"success": True} unconditionally.
    """
    # TODO: implement document updates — this stub reports success without
    # touching the database.
    return {"success": True}
|
||||
|
||||
def _delete_documents(self, document_ids: List[str]) -> Dict[str, Any]:
    """Delete documents by ID (stub).

    Args:
        document_ids: IDs of the documents to delete.

    Returns:
        {"success": True} unconditionally.
    """
    # TODO: implement deletion — this stub reports success without touching
    # the database.
    return {"success": True}
|
||||
|
||||
def _get_vectors(self, document_ids: List[str]) -> Dict[str, Any]:
|
||||
"""Get vectors for the specified document IDs"""
|
||||
try:
|
||||
# Get content IDs for the documents
|
||||
content_ids_query = """
|
||||
SELECT c.content_id FROM Contents c
|
||||
JOIN VersionsContents vc ON c.content_id = vc.content_id
|
||||
JOIN Versions v ON vc.version_id = v.version_id
|
||||
JOIN Documents d ON v.document_id = d.document_id
|
||||
WHERE d.document_id IN ({}) AND v.current_version = 1
|
||||
""".format(','.join(['?']*len(document_ids)))
|
||||
|
||||
content_ids_result = self.db_service.execute(content_ids_query, document_ids)
|
||||
content_ids = [r["content_id"] for r in content_ids_result]
|
||||
|
||||
# Get vectors for the content
|
||||
vectors_query = f"""
|
||||
SELECT * FROM Vectors WHERE content_id IN ({','.join(['?']*len(content_ids))})
|
||||
"""
|
||||
vectors = self.db_service.execute(vectors_query, content_ids)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"vectors": vectors
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving vectors: {e}")
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"vectors": []
|
||||
}
|
||||
|
||||
# Helper methods for database operations
|
||||
def _store_sources(self, documents: List[Any]) -> Dict[str, str]:
|
||||
"""Store sources and return a mapping of URL to source_id"""
|
||||
source_map = {}
|
||||
for doc in documents:
|
||||
url = doc.get("url", "")
|
||||
domain = self._extract_domain(url)
|
||||
|
||||
if domain not in source_map:
|
||||
source_id = self.id_generator()
|
||||
|
||||
# Check if source already exists
|
||||
query = "SELECT id FROM Sources WHERE id = ?"
|
||||
result = self.db_service.execute(query, [domain])
|
||||
|
||||
if not result:
|
||||
# Insert new source
|
||||
insert_query = "INSERT INTO Sources (id) VALUES (?)"
|
||||
self.db_service.execute(insert_query, [domain])
|
||||
|
||||
source_map[domain] = domain # Source ID is the domain
|
||||
|
||||
return source_map
|
||||
|
||||
def _store_document_entries(self, documents: List[Any], source_ids: Dict[str, str]) -> List[str]:
    """Store document entries and return document IDs.

    Args:
        documents: Documents to insert; each is read for its "url".
        source_ids: Mapping of domain -> source_id from _store_sources.

    Returns:
        Freshly generated document IDs, one per input document, in order.
    """
    document_ids = []

    for doc in documents:
        url = doc.get("url", "")
        domain = self._extract_domain(url)
        # NOTE(review): a domain missing from source_ids yields source_id
        # None and is still inserted — confirm that is acceptable.
        source_id = source_ids.get(domain)

        document_id = self.id_generator()
        document_type = self._determine_document_type(url)

        # Insert document
        insert_query = """
        INSERT INTO Documents (
            document_id, source_id, url, document_type,
            status, priority
        ) VALUES (?, ?, ?, ?, ?, ?)
        """

        params = [
            document_id,
            source_id,
            url,
            document_type,
            DocumentStatus.NEW.value,
            5  # Default priority
        ]

        self.db_service.execute(insert_query, params)
        document_ids.append(document_id)

    return document_ids
|
||||
|
||||
def _create_versions(self, document_ids: List[str]) -> List[str]:
    """Create initial versions for documents.

    For each document: inserts a Versions row marked as the current
    version ("1.0", ACTIVE), then updates the Documents row to point at it
    and flips the document status to COMPLETE.

    Args:
        document_ids: IDs of the documents to version.

    Returns:
        Generated version IDs, one per document, in order.
    """
    version_ids = []

    for document_id in document_ids:
        version_id = self.id_generator()

        # Insert version
        insert_query = """
        INSERT INTO Versions (
            version_id, document_id, current_version,
            version_number, status, processed_at
        ) VALUES (?, ?, ?, ?, ?, ?)
        """

        params = [
            version_id,
            document_id,
            True,  # Current version
            "1.0",  # Initial version
            VersionStatus.ACTIVE.value,
            datetime.now()
        ]

        self.db_service.execute(insert_query, params)

        # Update document with current version ID
        update_query = """
        UPDATE Documents
        SET current_version_id = ?, status = ?
        WHERE document_id = ?
        """

        update_params = [
            version_id,
            DocumentStatus.COMPLETE.value,
            document_id
        ]

        self.db_service.execute(update_query, update_params)
        version_ids.append(version_id)

    return version_ids
|
||||
|
||||
def _store_metadata(self, metadata_list: List[Any], document_ids: List[str]) -> None:
|
||||
"""Store metadata for documents"""
|
||||
for i, metadata in enumerate(metadata_list):
|
||||
if i >= len(document_ids):
|
||||
break
|
||||
|
||||
document_id = document_ids[i]
|
||||
metadata_id = self.id_generator()
|
||||
|
||||
# Insert metadata
|
||||
insert_query = """
|
||||
INSERT INTO Metadatas (
|
||||
metadata_id, document_id, other_metadata,
|
||||
created_at, updated_at
|
||||
) VALUES (?, ?, ?, ?, ?)
|
||||
"""
|
||||
|
||||
params = [
|
||||
metadata_id,
|
||||
document_id,
|
||||
metadata.get("metadata", "{}"),
|
||||
datetime.now(),
|
||||
datetime.now()
|
||||
]
|
||||
|
||||
self.db_service.execute(insert_query, params)
|
||||
|
||||
def _store_content(self, documents: List[Any], version_ids: List[str]) -> List[str]:
|
||||
"""Store content for document versions"""
|
||||
content_ids = []
|
||||
|
||||
for i, doc in enumerate(documents):
|
||||
if i >= len(version_ids):
|
||||
break
|
||||
|
||||
version_id = version_ids[i]
|
||||
content_id = self.id_generator()
|
||||
|
||||
# Insert content
|
||||
insert_query = """
|
||||
INSERT INTO Contents (
|
||||
content_id, version_id, raw_content,
|
||||
processed_content, hash
|
||||
) VALUES (?, ?, ?, ?, ?)
|
||||
"""
|
||||
|
||||
content = doc.get("content", "")
|
||||
processed_content = content # In reality, this might go through processing
|
||||
content_hash = self._generate_hash(content)
|
||||
|
||||
params = [
|
||||
content_id,
|
||||
version_id,
|
||||
content,
|
||||
processed_content,
|
||||
content_hash
|
||||
]
|
||||
|
||||
self.db_service.execute(insert_query, params)
|
||||
content_ids.append(content_id)
|
||||
|
||||
return content_ids
|
||||
|
||||
def _create_version_content_links(self, version_ids: List[str], content_ids: List[str]) -> None:
    """Link each version to its content row via VersionsContents.

    Pairs are positional; versions without a matching content entry are
    skipped (zip stops at the shorter list, as the original's break did).
    """
    # Insert version-content link
    insert_query = """
    INSERT INTO VersionsContents (
        version_id, content_id, created_at, source_type
    ) VALUES (?, ?, ?, ?)
    """

    for version_id, content_id in zip(version_ids, content_ids):
        self.db_service.execute(
            insert_query,
            [
                version_id,
                content_id,
                datetime.now(),
                SourceType.PRIMARY.value,
            ],
        )
|
||||
|
||||
def _store_vectors(self, vectors: List[Any], content_ids: List[str]) -> None:
|
||||
"""Store vectors for content"""
|
||||
for i, vector in enumerate(vectors):
|
||||
if i >= len(content_ids):
|
||||
break
|
||||
|
||||
content_id = content_ids[i]
|
||||
vector_id = self.id_generator()
|
||||
|
||||
# Insert vector
|
||||
insert_query = """
|
||||
INSERT INTO Vectors (
|
||||
vector_id, content_id, vector_embedding, embedding_type
|
||||
) VALUES (?, ?, ?, ?)
|
||||
"""
|
||||
|
||||
params = [
|
||||
vector_id,
|
||||
content_id,
|
||||
vector.get("embedding"),
|
||||
vector.get("embedding_type", "default")
|
||||
]
|
||||
|
||||
self.db_service.execute(insert_query, params)
|
||||
|
||||
# Utility methods
|
||||
def _generate_uuid(self) -> str:
|
||||
"""Generate a UUID string"""
|
||||
return str(uuid4())
|
||||
|
||||
def _extract_domain(self, url: str) -> str:
|
||||
"""Extract domain from URL"""
|
||||
import re
|
||||
match = re.search(r'https?://([^/]+)', url)
|
||||
return match.group(1) if match else url
|
||||
|
||||
def _determine_document_type(self, url: str) -> str:
|
||||
"""Determine document type from URL"""
|
||||
if url.endswith('.pdf'):
|
||||
return 'pdf'
|
||||
elif url.endswith('.doc') or url.endswith('.docx'):
|
||||
return 'word'
|
||||
else:
|
||||
return 'html'
|
||||
|
||||
def _generate_hash(self, content: str) -> str:
|
||||
"""Generate a hash for content"""
|
||||
import hashlib
|
||||
return hashlib.sha256(content.encode()).hexdigest()
|
@ -0,0 +1,122 @@
|
||||
from pydantic import BaseModel
|
||||
from typing import Dict, List, Any, Optional, Type
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from configs import Configs
|
||||
|
||||
|
||||
class SocialtoolkitConfigs(BaseModel):
    """Configuration for High Level Architecture workflow"""
    # Pre-approved source URLs/domains documents may be fetched from.
    approved_document_sources: List[str]
    # Settings passed through to the LLM API client.
    llm_api_config: Dict[str, Any]
    # NOTE(review): presumably caps how many documents a retrieval round
    # returns — confirm against the retrieval service.
    document_retrieval_threshold: int = 10
    # Minimum relevance score for a document to be kept.
    relevance_threshold: float = 0.7
    # Serialization format of the final output.
    output_format: str = "json"
    # When True, fetch documents live from the web rather than storage only.
    get_documents_from_web: bool = False
|
||||
|
||||
|
||||
class Socialtoolkit:
    """
    High Level Architecture for document retrieval and data extraction system
    based on mermaid chart in README.md
    """

    def __init__(self, resources: Dict[str, Any], configs: Configs):
        """
        Initialize with injected dependencies and configuration

        Args:
            resources: Dictionary of resources including services
            configs: Configuration for High Level Architecture
        """
        self.resources = resources
        self.configs: SocialtoolkitConfigs = configs.socialtoolkit

        # Extract needed services from resources BEFORE building llm_api:
        # the original constructed llm_api from self.llm_service first,
        # which raised AttributeError because llm_service was not yet set.
        self.document_retrieval = resources.get("document_retrieval_service")
        self.document_storage = resources.get("document_storage_service")
        self.llm_service = resources.get("llm_service")
        self.top10_retrieval = resources.get("top10_retrieval_service")
        self.relevance_assessment = resources.get("relevance_assessment_service")
        self.prompt_decision_tree = resources.get("prompt_decision_tree_service")
        self.variable_codebook = resources.get("variable_codebook_service")

        # NOTE(review): assumes the "llm_service" resource is a factory/class
        # accepting (resources, configs) — confirm against the registry.
        self.llm_api = self.llm_service(resources, configs) if self.llm_service else None

        logger.info("Socialtoolkit initialized with services")

    def execute(self, input_data_point: str) -> dict[str, str] | list[dict[str, str]]:
        """
        Execute the control flow based on the mermaid chart

        Args:
            input_data_point: The question or information request. This can be a single request.

        Returns:
            Dictionary containing the output data point.
            If the request was interpreted as having more than one response, a list of dictionaries is returned.
        """
        logger.info(f"Starting high level control flow with input: {input_data_point}")

        if self.configs.approved_document_sources:
            # Step 1: Get domain URLs from pre-approved sources.
            # (The original passed an undefined `domain_urls` variable here —
            # a guaranteed NameError; use the accessor instead.)
            domain_urls: list[str] = self._get_domain_urls()

            # Step 2: Retrieve documents from websites
            documents, metadata, vectors = self.document_retrieval.execute(domain_urls)

            # Step 3: Store documents in document storage
            storage_successful: bool = self.document_storage.execute(documents, metadata, vectors)
            if storage_successful:
                logger.info("Documents stored successfully")
            else:
                logger.warning("Failed to store documents")

        # Steps 4-8 run regardless of whether fresh documents were ingested.
        # NOTE(review): the original's indentation was ambiguous here —
        # confirm these steps are intended outside the sources branch.

        # Step 4: Retrieve documents and document vectors
        stored_docs, stored_vectors = self.document_retrieval.execute(
            input_data_point,
            self.llm_service.execute("retrieve_documents")
        )

        # Step 5: Perform top-10 document retrieval
        potentially_relevant_docs = self.top10_retrieval.execute(
            input_data_point,
            stored_docs,
            stored_vectors
        )

        # Step 6: Get variable definition from codebook
        prompt_sequence = self.variable_codebook.execute(self.llm_service, input_data_point)

        # Step 7: Perform relevance assessment
        relevant_documents = self.relevance_assessment.execute(
            potentially_relevant_docs,
            prompt_sequence,
            self.llm_service.execute("relevance_assessment")
        )

        # Step 8: Execute prompt decision tree
        output_data_point = self.prompt_decision_tree.execute(
            relevant_documents,
            prompt_sequence,
            self.llm_service.execute("prompt_decision_tree")
        )

        if output_data_point is None:
            logger.warning("Failed to execute prompt decision tree")
        else:
            logger.info(f"Completed high level control flow with output: {output_data_point}")

        return {"output_data_point": output_data_point}

    def _get_domain_urls(self) -> List[str]:
        """Extract domain URLs from pre-approved document sources"""
        return self.configs.approved_document_sources
|
@ -0,0 +1,27 @@
|
||||
from logging import Logger
|
||||
from typing import Any
|
||||
|
||||
from configs import Configs
|
||||
|
||||
class LLMService:
    """Facade over the injected LLM resources (model, tokenizer, vectorizer,
    vector storage). Concrete command handling is not yet implemented."""

    def __init__(self, resources: dict[str, Any], configs: Configs):
        """
        Args:
            resources: Injected resources; may provide "logger", "llm_model",
                "llm_tokenizer", "llm_vectorizer", "llm_vector_storage".
            configs: Global configuration object.
        """
        import logging  # local: the file only imports Logger for typing

        self.resources = resources
        self.configs = configs
        # Fall back to a module logger so a missing "logger" resource does
        # not crash initialization with AttributeError on None.
        self.logger: Logger = resources.get("logger") or logging.getLogger(__name__)

        self.llm_model = resources.get("llm_model")
        self.llm_tokenizer = resources.get("llm_tokenizer")
        self.llm_vectorizer = resources.get("llm_vectorizer")
        self.llm_vector_storage = resources.get("llm_vector_storage")

        self.logger.info("LLMService initialized with services")

    def execute(self, command_context: str, *args, **kwargs):
        """Execute an LLM command identified by *command_context*.

        Args:
            command_context: Name of the command to run (e.g.
                "retrieve_documents", "relevance_assessment").
            *args: Positional arguments for the command.
            **kwargs: Keyword arguments for the command.
        """
        # TODO: not yet implemented — currently a no-op returning None.
        pass
|
||||
|
||||
|
@ -0,0 +1,366 @@
|
||||
from pydantic import BaseModel
|
||||
from typing import Dict, List, Any, Optional, Union
|
||||
import logging
|
||||
from enum import Enum
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class PromptDecisionTreeConfigs(BaseModel):
    """Configuration for Prompt Decision Tree workflow"""
    # Max tokens requested per LLM generation call.
    max_tokens_per_prompt: int = 2000
    # Cap on pages merged into one document before prompting.
    max_pages_to_concatenate: int = 10
    # Safety bound on tree traversal steps.
    max_iterations: int = 5
    # NOTE(review): declared but not referenced in the visible tree logic.
    confidence_threshold: float = 0.7
    enable_human_review: bool = True  # Whether to enable human review for low confidence or errors
    context_window_size: int = 8000  # Maximum context window size for LLM
|
||||
|
||||
class PromptDecisionTreeNodeType(str, Enum):
    """Types of nodes in the prompt decision tree"""
    QUESTION = "question"  # node that asks the LLM a question
    DECISION = "decision"  # branching node
    FINAL = "final"        # terminal node
|
||||
|
||||
class PromptDecisionTreeEdge(BaseModel):
    """Edge in the prompt decision tree"""
    # Condition label under which this edge is taken (e.g. "default").
    condition: str
    # ID of the node this edge leads to.
    next_node_id: str
|
||||
|
||||
class PromptDecisionTreeNode(BaseModel):
    """Node in the prompt decision tree"""
    # Unique node identifier (e.g. "node_0").
    id: str
    # Kind of node; see PromptDecisionTreeNodeType.
    type: PromptDecisionTreeNodeType
    # Question/prompt text posed to the LLM at this node.
    prompt: str
    # Outgoing edges; None for nodes with no explicit branching.
    edges: Optional[List[PromptDecisionTreeEdge]] = None
    # True when traversal should stop after this node.
    is_final: bool = False
|
||||
|
||||
class PromptDecisionTree:
    """
    Prompt Decision Tree system based on mermaid flowchart in README.md
    Executes a decision tree of prompts to extract information from documents
    """

    def __init__(self, resources: Dict[str, Any], configs: PromptDecisionTreeConfigs):
        """
        Initialize with injected dependencies and configuration

        Args:
            resources: Dictionary of resources including services
            configs: Configuration for Prompt Decision Tree
        """
        self.resources = resources
        self.configs = configs

        # Extract needed services from resources
        # NOTE(review): lookups default to None when absent and fail lazily.
        self.variable_codebook = resources.get("variable_codebook_service")
        self.human_review_service = resources.get("human_review_service")

        logger.info("PromptDecisionTree initialized with services")

    def control_flow(self, relevant_pages: List[Any],
                     prompt_sequence: List[str],
                     llm_api: Any) -> Dict[str, Any]:
        """
        Execute the prompt decision tree flow based on the mermaid flowchart

        Args:
            relevant_pages: List of relevant document pages
            prompt_sequence: List of prompts in the decision tree
            llm_api: LLM API instance

        Returns:
            Dictionary containing the output data point
        """
        logger.info(f"Starting prompt decision tree with {len(relevant_pages)} pages")

        # Step 1: Concatenate pages
        concatenated_pages = self._concatenate_pages(relevant_pages)

        # Step 2: Get desired data point codebook entry & prompt sequence
        # (Already provided as input parameter)

        # Step 3: Execute prompt decision tree
        result = self._execute_decision_tree(
            concatenated_pages, prompt_sequence, llm_api
        )

        # Step 4: Handle errors and unforeseen edgecases if needed
        if result.get("error") and self.configs.enable_human_review:
            result = self._request_human_review(result, concatenated_pages)

        logger.info("Completed prompt decision tree execution")
        return result

    def execute(self, relevant_pages: List[Any], prompt_sequence: List[str],
                llm_api: Any) -> Any:
        """
        Public method to execute prompt decision tree

        Args:
            relevant_pages: List of relevant document pages
            prompt_sequence: List of prompts in the decision tree
            llm_api: LLM API instance

        Returns:
            Output data point
        """
        result = self.control_flow(relevant_pages, prompt_sequence, llm_api)
        return result.get("output_data_point", "")

    def _concatenate_pages(self, pages: List[Any]) -> str:
        """
        Concatenate pages into a single document

        Args:
            pages: List of pages to concatenate

        Returns:
            Concatenated document text
        """
        # Limit number of pages to avoid context window issues
        pages_to_use = pages[:self.configs.max_pages_to_concatenate]

        concatenated_text = ""

        for i, page in enumerate(pages_to_use):
            content = page.get("content", "")
            title = page.get("title", f"Document {i+1}")
            url = page.get("url", "")

            page_text = f"""
--- DOCUMENT {i+1}: {title} ---
Source: {url}

{content}

"""
            concatenated_text += page_text

        return concatenated_text

    def _execute_decision_tree(self, document_text: str,
                               prompt_sequence: List[str],
                               llm_api: Any) -> Dict[str, Any]:
        """
        Execute the prompt decision tree

        Args:
            document_text: Concatenated document text
            prompt_sequence: List of prompts in the decision tree
            llm_api: LLM API instance

        Returns:
            Dictionary containing the execution result
        """
        # Create a simplified decision tree from the prompt sequence
        decision_tree = self._create_decision_tree(prompt_sequence)

        try:
            # Start with the first node
            # NOTE(review): an empty prompt_sequence raises IndexError here,
            # which falls into the except branch and returns an error dict.
            current_node = decision_tree[0]
            iteration = 0
            responses = []

            # Follow the decision tree until a final node is reached or max iterations is exceeded
            while not current_node.is_final and iteration < self.configs.max_iterations:
                # Generate prompt for the current node
                prompt = self._generate_node_prompt(current_node, document_text)

                # Get response from LLM
                llm_response = llm_api.generate(prompt, max_tokens=self.configs.max_tokens_per_prompt)
                responses.append({
                    "node_id": current_node.id,
                    "prompt": prompt,
                    "response": llm_response
                })

                # Determine next node based on response
                if current_node.edges:
                    next_node_id = self._determine_next_node(llm_response, current_node.edges)
                    current_node = next(
                        (node for node in decision_tree if node.id == next_node_id),
                        decision_tree[-1]  # Default to the last node if not found
                    )
                else:
                    # No edges, move to the next node in sequence
                    node_index = decision_tree.index(current_node)
                    if node_index + 1 < len(decision_tree):
                        current_node = decision_tree[node_index + 1]
                    else:
                        # End of sequence, mark as final
                        current_node.is_final = True

                iteration += 1

            # Process the final response
            final_response = responses[-1]["response"] if responses else ""
            output_data_point = self._extract_output_data_point(final_response)

            return {
                "success": True,
                "output_data_point": output_data_point,
                "responses": responses,
                "iterations": iteration
            }

        except Exception as e:
            logger.error(f"Error executing decision tree: {e}")
            return {
                "success": False,
                "error": str(e),
                "output_data_point": ""
            }

    def _create_decision_tree(self, prompt_sequence: List[str]) -> List[PromptDecisionTreeNode]:
        """
        Create a decision tree from a prompt sequence

        This is a simplified implementation that creates a linear sequence of nodes.
        In a real system, this would create a proper tree structure with branches.

        Args:
            prompt_sequence: List of prompts

        Returns:
            List of nodes in the decision tree
        """
        nodes = []

        for i, prompt in enumerate(prompt_sequence):
            # Create a node for each prompt
            node = PromptDecisionTreeNode(
                id=f"node_{i}",
                type=PromptDecisionTreeNodeType.QUESTION,
                prompt=prompt,
                is_final=(i == len(prompt_sequence) - 1)  # Last node is final
            )

            # Add edges if not the last node
            if i < len(prompt_sequence) - 1:
                node.edges = [
                    PromptDecisionTreeEdge(
                        condition="default",
                        next_node_id=f"node_{i+1}"
                    )
                ]

            nodes.append(node)

        return nodes

    def _generate_node_prompt(self, node: PromptDecisionTreeNode, document_text: str) -> str:
        """
        Generate a prompt for a node in the decision tree

        Args:
            node: Current node in the decision tree
            document_text: Document text

        Returns:
            Prompt for the node
        """
        # Truncate document text if too long
        max_doc_length = self.configs.context_window_size - 500  # Reserve space for instructions
        if len(document_text) > max_doc_length:
            document_text = document_text[:max_doc_length] + "..."

        prompt = f"""
You are an expert tax researcher assisting with data extraction from official documents.
Please carefully analyze the following documents to answer this specific question:

QUESTION: {node.prompt}

DOCUMENTS:
{document_text}

Based solely on the information provided in these documents, please answer the question above.
If the answer is explicitly stated in the documents, provide the exact information along with its source.
If the answer requires interpretation, explain your reasoning clearly.
If the information is not available in the documents, respond with "Information not available in the provided documents."

Your answer should be concise, factual, and directly address the question.
"""
        return prompt

    def _determine_next_node(self, response: str, edges: List[PromptDecisionTreeEdge]) -> str:
        """
        Determine the next node based on the response

        This is a simplified implementation that just follows the default edge.
        In a real system, this would analyze the response to determine the path.

        Args:
            response: LLM response
            edges: List of edges from the current node

        Returns:
            ID of the next node
        """
        # In this simplified version, just follow the first edge
        if edges:
            return edges[0].next_node_id
        return ""

    def _extract_output_data_point(self, response: str) -> str:
        """
        Extract the output data point from the final response

        Args:
            response: Final LLM response

        Returns:
            Extracted output data point
        """
        # Look for patterns like "X%" or "X percent"
        import re

        # Try to find percentage patterns
        percentage_match = re.search(r'(\d+(?:\.\d+)?)\s*%', response)
        if percentage_match:
            return percentage_match.group(0)

        percentage_word_match = re.search(r'(\d+(?:\.\d+)?)\s+percent', response, re.IGNORECASE)
        if percentage_word_match:
            value = percentage_word_match.group(1)
            return f"{value}%"

        # Look for specific statements about rates
        rate_match = re.search(r'rate\s+is\s+(\d+(?:\.\d+)?)', response, re.IGNORECASE)
        if rate_match:
            value = rate_match.group(1)
            return f"{value}%"

        # If no specific patterns are found, return a cleaned up version of the response
        # Limit to 100 characters for brevity
        cleaned_response = response.strip()
        if len(cleaned_response) > 100:
            cleaned_response = cleaned_response[:97] + "..."

        return cleaned_response

    def _request_human_review(self, result: Dict[str, Any], document_text: str) -> Dict[str, Any]:
        """
        Request human review for errors or low confidence results

        Args:
            result: Result from decision tree execution
            document_text: Document text

        Returns:
            Updated result after human review
        """
        if self.human_review_service:
            review_request = {
                "error": result.get("error"),
                "document_text": document_text,
                "responses": result.get("responses", [])
            }

            human_review_result = self.human_review_service.review(review_request)

            if human_review_result.get("success"):
                result["output_data_point"] = human_review_result.get("output_data_point", "")
                result["human_reviewed"] = True
                result["success"] = True
                result.pop("error", None)

        return result
|
@ -0,0 +1,452 @@
|
||||
from pydantic import BaseModel
|
||||
from typing import Dict, List, Any, Optional
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class RelevanceAssessmentConfigs(BaseModel):
    """Configuration for Relevance Assessment workflow"""
    criteria_threshold: float = 0.7  # Minimum relevance score threshold
    max_retries: int = 3  # Maximum number of retry attempts for LLM API calls
    max_citation_length: int = 500  # Maximum length of text citations
    use_hallucination_filter: bool = True  # Whether to filter for hallucinations
|
||||
|
||||
class RelevanceAssessment:
|
||||
"""
|
||||
Relevance Assessment system based on mermaid flowchart in README.md
|
||||
Evaluates document relevance using LLM assessments
|
||||
"""
|
||||
|
||||
def __init__(self, resources: Dict[str, Any], configs: RelevanceAssessmentConfigs):
    """
    Initialize with injected dependencies and configuration

    Args:
        resources: Dictionary of resources including services
        configs: Configuration for Relevance Assessment
    """
    self.resources = resources
    self.configs = configs

    # Extract needed services from resources
    # NOTE(review): lookups default to None when a service is absent and
    # will only fail when first used.
    self.variable_codebook = resources.get("variable_codebook_service")
    self.top10_retrieval = resources.get("top10_retrieval_service")
    self.cited_page_extractor = resources.get("cited_page_extractor_service")
    self.prompt_decision_tree = resources.get("prompt_decision_tree_service")

    logger.info("RelevanceAssessment initialized with services")
|
||||
|
||||
def control_flow(self, potentially_relevant_docs: List[Any],
                 variable_definition: Dict[str, Any],
                 llm_api: Any) -> Dict[str, Any]:
    """
    Execute the relevance assessment flow based on the mermaid flowchart

    Args:
        potentially_relevant_docs: List of potentially relevant documents
        variable_definition: Variable definition and description
        llm_api: LLM API instance

    Returns:
        Dictionary with keys "relevant_pages", "relevant_doc_ids",
        "page_numbers", and "relevance_scores".
    """
    logger.info(f"Starting relevance assessment for {len(potentially_relevant_docs)} documents")

    # Step 1: Assess document relevance
    assessment_results = self._assess_document_relevance(
        potentially_relevant_docs,
        variable_definition,
        llm_api
    )

    # Step 2: Filter for hallucinations if configured
    if self.configs.use_hallucination_filter:
        assessment_results = self._filter_hallucinations(assessment_results, llm_api)

    # Step 3: Score relevance
    relevance_scores = self._score_relevance(assessment_results, potentially_relevant_docs)

    # Step 4: Apply threshold to separate relevant from irrelevant
    relevant_pages, discarded_pages = self._apply_threshold(relevance_scores)

    # Step 5: Extract page numbers
    page_numbers = self._extract_page_numbers(relevant_pages)

    # Step 6: Extract cited pages
    relevant_pages_content = self._extract_cited_pages(
        potentially_relevant_docs, page_numbers
    )

    logger.info(f"Completed relevance assessment: {len(relevant_pages_content)} relevant pages")
    return {
        "relevant_pages": relevant_pages_content,
        "relevant_doc_ids": [page["doc_id"] for page in relevant_pages],
        "page_numbers": page_numbers,
        "relevance_scores": relevance_scores
    }
|
||||
|
||||
def assess(self, potentially_relevant_docs: List[Any],
|
||||
prompt_sequence: List[str], llm_api: Any) -> List[Any]:
|
||||
"""
|
||||
Public method to assess document relevance
|
||||
|
||||
Args:
|
||||
potentially_relevant_docs: List of potentially relevant documents
|
||||
prompt_sequence: List of prompts to use for assessment
|
||||
llm_api: LLM API instance
|
||||
|
||||
Returns:
|
||||
List of relevant documents
|
||||
"""
|
||||
# Get variable definition from prompt sequence
|
||||
variable_definition = {
|
||||
"prompt_sequence": prompt_sequence,
|
||||
"description": "Tax information for business operations" # Default description if not available
|
||||
}
|
||||
|
||||
result = self.control_flow(potentially_relevant_docs, variable_definition, llm_api)
|
||||
return result["relevant_pages"]
|
||||
|
||||
def _assess_document_relevance(self, docs: List[Any],
|
||||
variable_definition: Dict[str, Any],
|
||||
llm_api: Any) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Assess document relevance using LLM
|
||||
|
||||
Args:
|
||||
docs: List of documents to assess
|
||||
variable_definition: Variable definition and description
|
||||
llm_api: LLM API instance
|
||||
|
||||
Returns:
|
||||
List of assessment results
|
||||
"""
|
||||
assessment_results = []
|
||||
|
||||
for doc in docs:
|
||||
# Create assessment prompt
|
||||
assessment_prompt = self._create_assessment_prompt(doc, variable_definition)
|
||||
|
||||
# Get LLM assessment
|
||||
try:
|
||||
llm_response = llm_api.generate(assessment_prompt, max_tokens=1000)
|
||||
|
||||
# Parse assessment results
|
||||
assessment = self._parse_assessment(llm_response, doc)
|
||||
assessment_results.append(assessment)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error assessing document {doc.get('id')}: {e}")
|
||||
# Add failed assessment
|
||||
assessment_results.append({
|
||||
"doc_id": doc.get("id"),
|
||||
"relevant": False,
|
||||
"confidence": 0.0,
|
||||
"citation": "",
|
||||
"error": str(e)
|
||||
})
|
||||
|
||||
return assessment_results
|
||||
|
||||
def _filter_hallucinations(self, assessments: List[Dict[str, Any]],
|
||||
llm_api: Any) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Filter for hallucinations in LLM assessments
|
||||
|
||||
Args:
|
||||
assessments: List of assessment results
|
||||
llm_api: LLM API instance
|
||||
|
||||
Returns:
|
||||
Filtered list of assessment results
|
||||
"""
|
||||
filtered_assessments = []
|
||||
|
||||
for assessment in assessments:
|
||||
# Skip already irrelevant assessments
|
||||
if not assessment.get("relevant", False):
|
||||
filtered_assessments.append(assessment)
|
||||
continue
|
||||
|
||||
# Create hallucination check prompt
|
||||
hallucination_prompt = self._create_hallucination_prompt(assessment)
|
||||
|
||||
try:
|
||||
# Get LLM hallucination check
|
||||
hallucination_response = llm_api.generate(hallucination_prompt, max_tokens=500)
|
||||
|
||||
# Parse hallucination check
|
||||
is_hallucination = self._parse_hallucination_check(hallucination_response)
|
||||
|
||||
if is_hallucination:
|
||||
# Downgrade relevance for hallucinations
|
||||
assessment["relevant"] = False
|
||||
assessment["confidence"] = 0.0
|
||||
assessment["hallucination"] = True
|
||||
|
||||
filtered_assessments.append(assessment)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking hallucination for document {assessment.get('doc_id')}: {e}")
|
||||
# Keep original assessment in case of error
|
||||
filtered_assessments.append(assessment)
|
||||
|
||||
return filtered_assessments
|
||||
|
||||
def _score_relevance(self, assessments: List[Dict[str, Any]],
|
||||
docs: List[Any]) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Score relevance based on LLM assessments
|
||||
|
||||
Args:
|
||||
assessments: List of assessment results
|
||||
docs: List of original documents
|
||||
|
||||
Returns:
|
||||
List of documents with relevance scores
|
||||
"""
|
||||
# Create a dict mapping document ID to original document
|
||||
doc_map = {doc.get("id"): doc for doc in docs}
|
||||
|
||||
relevance_scores = []
|
||||
|
||||
for assessment in assessments:
|
||||
doc_id = assessment.get("doc_id")
|
||||
doc = doc_map.get(doc_id)
|
||||
|
||||
if not doc:
|
||||
logger.warning(f"Document not found for ID: {doc_id}")
|
||||
continue
|
||||
|
||||
# Calculate relevance score based on LLM confidence
|
||||
relevance_score = {
|
||||
"doc_id": doc_id,
|
||||
"page_number": assessment.get("page_number", 1), # Default to page 1 if not specified
|
||||
"score": assessment.get("confidence", 0.0),
|
||||
"relevant": assessment.get("relevant", False),
|
||||
"citation": assessment.get("citation", ""),
|
||||
"content": doc.get("content", "")
|
||||
}
|
||||
|
||||
relevance_scores.append(relevance_score)
|
||||
|
||||
return relevance_scores
|
||||
|
||||
def _apply_threshold(self, relevance_scores: List[Dict[str, Any]]) -> tuple:
|
||||
"""
|
||||
Apply threshold to relevance scores
|
||||
|
||||
Args:
|
||||
relevance_scores: List of documents with relevance scores
|
||||
|
||||
Returns:
|
||||
Tuple of (relevant_pages, discarded_pages)
|
||||
"""
|
||||
relevant_pages = []
|
||||
discarded_pages = []
|
||||
|
||||
for score in relevance_scores:
|
||||
if score.get("score", 0.0) >= self.configs.criteria_threshold:
|
||||
relevant_pages.append(score)
|
||||
else:
|
||||
discarded_pages.append(score)
|
||||
|
||||
return relevant_pages, discarded_pages
|
||||
|
||||
def _extract_page_numbers(self, relevant_pages: List[Dict[str, Any]]) -> Dict[str, List[int]]:
|
||||
"""
|
||||
Extract page numbers from relevant pages
|
||||
|
||||
Args:
|
||||
relevant_pages: List of relevant pages
|
||||
|
||||
Returns:
|
||||
Dictionary mapping document IDs to lists of page numbers
|
||||
"""
|
||||
page_numbers = {}
|
||||
|
||||
for page in relevant_pages:
|
||||
doc_id = page.get("doc_id")
|
||||
page_number = page.get("page_number", 1) # Default to page 1 if not specified
|
||||
|
||||
if doc_id not in page_numbers:
|
||||
page_numbers[doc_id] = []
|
||||
|
||||
page_numbers[doc_id].append(page_number)
|
||||
|
||||
return page_numbers
|
||||
|
||||
def _extract_cited_pages(self, docs: List[Any],
|
||||
page_numbers: Dict[str, List[int]]) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Extract cited pages from documents
|
||||
|
||||
Args:
|
||||
docs: List of original documents
|
||||
page_numbers: Dictionary mapping document IDs to lists of page numbers
|
||||
|
||||
Returns:
|
||||
List of relevant page contents
|
||||
"""
|
||||
# If cited page extractor service is available, use it
|
||||
if self.cited_page_extractor:
|
||||
return self.cited_page_extractor.extract(docs, page_numbers)
|
||||
|
||||
# Fallback implementation
|
||||
cited_pages = []
|
||||
|
||||
# Create a dict mapping document ID to original document
|
||||
doc_map = {doc.get("id"): doc for doc in docs}
|
||||
|
||||
for doc_id, page_nums in page_numbers.items():
|
||||
doc = doc_map.get(doc_id)
|
||||
|
||||
if not doc:
|
||||
logger.warning(f"Document not found for ID: {doc_id}")
|
||||
continue
|
||||
|
||||
# Extract content for each page
|
||||
for page_num in page_nums:
|
||||
# In this simplified implementation, we assume the whole document is the content
|
||||
# In a real system, this would extract specific pages from multi-page documents
|
||||
cited_pages.append({
|
||||
"doc_id": doc_id,
|
||||
"page_number": page_num,
|
||||
"content": doc.get("content", ""),
|
||||
"title": doc.get("title", ""),
|
||||
"url": doc.get("url", "")
|
||||
})
|
||||
|
||||
return cited_pages
|
||||
|
||||
    def _create_assessment_prompt(self, doc: Dict[str, Any],
                                  variable_definition: Dict[str, Any]) -> str:
        """
        Create assessment prompt for document relevance.

        Combines the information need (description), the key questions
        (prompt sequence) and the truncated document content, then appends
        the response format that _parse_assessment expects
        (RELEVANT / CONFIDENCE / CITATION / REASONING).

        Args:
            doc: Document to assess; only its "content" field is read here
            variable_definition: Variable definition and description

        Returns:
            Assessment prompt string ready for llm_api.generate
        """
        # Get document content
        content = doc.get("content", "")[:5000]  # Limit content length

        # Get variable information
        description = variable_definition.get("description", "")
        prompt_sequence = variable_definition.get("prompt_sequence", [])

        # Create assessment prompt.
        # NOTE: chr(10) is "\n" — backslashes were not allowed inside f-string
        # expressions before Python 3.12, hence this join idiom.
        prompt = f"""
        You are a document relevance assessor. Your task is to determine if the following document is relevant to the given information need.

        Information Need: {description}

        Key Questions:
        {chr(10).join([f"- {p}" for p in prompt_sequence])}

        Document Content:
        {content}

        Please assess the document's relevance to the information need based on the following criteria:
        1. Does the document contain information directly related to the information need?
        2. Does the document provide sufficient detail to answer at least one of the key questions?
        3. Is the document from a credible source?

        Provide your assessment in the following format:
        RELEVANT: [Yes/No]
        CONFIDENCE: [0.0-1.0]
        CITATION: [Most relevant text snippet from the document that supports your assessment]
        REASONING: [Brief explanation for your assessment]
        """
        return prompt
def _parse_assessment(self, llm_response: str, doc: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Parse LLM assessment response
|
||||
|
||||
Args:
|
||||
llm_response: LLM response text
|
||||
doc: Original document
|
||||
|
||||
Returns:
|
||||
Parsed assessment
|
||||
"""
|
||||
assessment = {
|
||||
"doc_id": doc.get("id"),
|
||||
"relevant": False,
|
||||
"confidence": 0.0,
|
||||
"citation": "",
|
||||
"reasoning": ""
|
||||
}
|
||||
|
||||
try:
|
||||
# Parse relevant
|
||||
if "RELEVANT: Yes" in llm_response:
|
||||
assessment["relevant"] = True
|
||||
|
||||
# Parse confidence
|
||||
confidence_match = re.search(r"CONFIDENCE: (0\.\d+|1\.0)", llm_response)
|
||||
if confidence_match:
|
||||
assessment["confidence"] = float(confidence_match.group(1))
|
||||
|
||||
# Parse citation
|
||||
citation_match = re.search(r"CITATION: (.*?)(?=REASONING:|$)", llm_response, re.DOTALL)
|
||||
if citation_match:
|
||||
citation = citation_match.group(1).strip()
|
||||
# Truncate if necessary
|
||||
if len(citation) > self.configs.max_citation_length:
|
||||
citation = citation[:self.configs.max_citation_length] + "..."
|
||||
assessment["citation"] = citation
|
||||
|
||||
# Parse reasoning
|
||||
reasoning_match = re.search(r"REASONING: (.*?)$", llm_response, re.DOTALL)
|
||||
if reasoning_match:
|
||||
assessment["reasoning"] = reasoning_match.group(1).strip()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error parsing assessment: {e}")
|
||||
|
||||
return assessment
|
||||
|
||||
    def _create_hallucination_prompt(self, assessment: Dict[str, Any]) -> str:
        """
        Create prompt to check for hallucinations.

        Asks the LLM to verify that the citation extracted for this
        assessment actually contains specific tax information; the answer is
        parsed by _parse_hallucination_check ("HALLUCINATION: Yes"/"No").

        Args:
            assessment: Document assessment (only its "citation" is used)

        Returns:
            Hallucination check prompt
        """
        citation = assessment.get("citation", "")

        # NOTE(review): the check is tax-domain specific, matching the default
        # description used in assess(); generalize if reused elsewhere.
        prompt = f"""
        You are a fact-checking assistant. Your task is to analyze the following excerpt and determine if it directly addresses tax rates, specific tax information, or tax regulations.

        Text excerpt:
        {citation}

        Please analyze this text and determine if it contains SPECIFIC information about tax rates, tax percentages, or tax regulations.
        Answer with "HALLUCINATION: Yes" if the text does NOT contain specific tax information.
        Answer with "HALLUCINATION: No" if the text DOES contain specific tax information.

        Provide a brief explanation for your decision.
        """
        return prompt
def _parse_hallucination_check(self, hallucination_response: str) -> bool:
|
||||
"""
|
||||
Parse hallucination check response
|
||||
|
||||
Args:
|
||||
hallucination_response: LLM response text
|
||||
|
||||
Returns:
|
||||
True if hallucination detected, False otherwise
|
||||
"""
|
||||
return "HALLUCINATION: Yes" in hallucination_response
|
||||
|
||||
import re # Added for regex pattern matching in parsing
|
@ -0,0 +1,295 @@
|
||||
from pydantic import BaseModel
|
||||
from typing import Dict, List, Any, Optional, Tuple
|
||||
import logging
|
||||
import numpy as np
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class Top10DocumentRetrievalConfigs(BaseModel):
    """Configuration for the Top-10 Document Retrieval workflow."""
    # How many top-ranked documents to return.
    retrieval_count: int = 10  # Number of documents to retrieve
    # Minimum similarity score; only applied when use_filter is True.
    similarity_threshold: float = 0.6  # Minimum similarity score
    # Scoring strategy used by the fallback similarity search.
    ranking_method: str = "cosine_similarity"  # Options: cosine_similarity, dot_product, euclidean
    # Whether to drop results below similarity_threshold before taking top N.
    use_filter: bool = False  # Whether to filter results
    # NOTE(review): declared but not read anywhere in this file — confirm intent.
    # (Pydantic copies mutable defaults per instance, so the {} default is safe.)
    filter_criteria: Dict[str, Any] = {}
    # NOTE(review): declared but no reranking step exists yet — confirm intent.
    use_reranking: bool = False  # Whether to use reranking
class Top10DocumentRetrieval:
    """
    Top-10 Document Retrieval system based on mermaid chart in README.md.
    Performs vector search to find the most relevant documents.
    """

    def __init__(self, resources: Dict[str, Any], configs: Top10DocumentRetrievalConfigs):
        """
        Initialize with injected dependencies and configuration.

        Args:
            resources: Dictionary of resources including search services
            configs: Configuration for Top-10 Document Retrieval
        """
        self.resources = resources
        self.configs = configs

        # Extract needed services from resources; any may be None, in which
        # case the in-class fallbacks below are used where one exists.
        self.encoder_service = resources.get("encoder_service")
        self.similarity_search_service = resources.get("similarity_search_service")
        self.document_storage = resources.get("document_storage_service")

        logger.info("Top10DocumentRetrieval initialized with services")

    def execute(self,
                input_data_point: str,
                documents: Optional[List[Any]] = None,
                document_vectors: Optional[List[Any]] = None
                ) -> dict[str, Any]:
        """
        Execute the document retrieval flow based on the mermaid chart.

        Steps: encode query -> load corpus if not supplied -> similarity
        search -> rank -> take top N -> fetch the matching documents.

        Args:
            input_data_point: The query or information request
            documents: Optional list of documents to search
            document_vectors: Optional list of document vectors to search

        Returns:
            Dictionary of documents containing potentially relevant documents,
            along with potentially relevant metadata.
        """
        logger.info(f"Starting top-10 document retrieval for: {input_data_point}")

        # Step 1: Encode the query
        encoded_query = self._encode_query(input_data_point)

        # Step 2: Get vector embeddings and document IDs from storage if not provided
        if documents is None or document_vectors is None:
            documents, document_vectors = self._get_documents_and_vectors()

        # Step 3: Perform similarity search
        similarity_scores, doc_ids = self._similarity_search(
            encoded_query,
            document_vectors,
            [doc.get("id") for doc in documents]
        )

        # Step 4: Rank and sort results
        ranked_results = self._rank_and_sort_results(similarity_scores, doc_ids)

        # Step 5: Filter to top-N results
        top_doc_ids = self._filter_to_top_n(ranked_results)

        # Step 6: Retrieve potentially relevant documents
        potentially_relevant_docs = self._retrieve_relevant_documents(documents, top_doc_ids)

        logger.info(f"Retrieved {len(potentially_relevant_docs)} potentially relevant documents")
        return {
            "relevant_documents": potentially_relevant_docs,
            "scores": {doc_id: score for doc_id, score in ranked_results},
            "top_doc_ids": top_doc_ids
        }

    def retrieve_top_documents(self, input_data_point: str, documents: List[Any], document_vectors: List[Any]) -> List[Any]:
        """
        Public method to retrieve top documents for an input query.

        Args:
            input_data_point: The query to search for
            documents: Documents to search
            document_vectors: Vectors for the documents

        Returns:
            List of potentially relevant documents
        """
        # BUG FIX: this previously called self.control_flow(...), which does
        # not exist on this class (AttributeError at runtime); the
        # orchestration entry point is execute().
        result = self.execute(input_data_point, documents, document_vectors)
        return result["relevant_documents"]

    def _encode_query(self, input_data_point: str) -> Any:
        """
        Encode the input query into a vector representation.

        Args:
            input_data_point: The query to encode

        Returns:
            Vector representation of the query (encoder service's output type)
        """
        logger.debug(f"Encoding query: {input_data_point}")
        return self.encoder_service.encode(input_data_point)

    def _get_documents_and_vectors(self) -> Tuple[List[Any], List[Any]]:
        """
        Get all documents and their vectors from storage.

        Returns:
            Tuple of (documents, document_vectors)
        """
        logger.debug("Getting documents and vectors from storage")
        return self.document_storage.get_documents_and_vectors()

    def _similarity_search(self, encoded_query: Any, document_vectors: List[Any],
                           doc_ids: List[str]) -> Tuple[List[float], List[str]]:
        """
        Perform similarity search between the query and document vectors.

        Delegates to the injected similarity-search service when available;
        otherwise scores each vector with the configured ranking method.

        Args:
            encoded_query: Vector representation of the query
            document_vectors: List of document vector embeddings
            doc_ids: List of document IDs corresponding to the vectors

        Returns:
            Tuple of (similarity_scores, document_ids)
        """
        logger.debug("Performing similarity search")

        # FIX: check the dedicated service first — previously the fallback
        # scores were fully computed and then thrown away when a service
        # existed, wasting O(n * d) work per query.
        if self.similarity_search_service:
            return self.similarity_search_service.search(
                encoded_query, document_vectors, doc_ids
            )

        # Fallback: brute-force scoring with the configured method.
        similarity_scores = []
        for vector in document_vectors:
            if self.configs.ranking_method == "cosine_similarity":
                score = self._cosine_similarity(encoded_query, vector.get("embedding"))
            elif self.configs.ranking_method == "dot_product":
                score = self._dot_product(encoded_query, vector.get("embedding"))
            elif self.configs.ranking_method == "euclidean":
                score = self._euclidean_distance(encoded_query, vector.get("embedding"))
                # Convert distance to similarity score (higher is more similar)
                score = 1.0 / (1.0 + score)
            else:
                # Unknown ranking method: score 0 so the document ranks last.
                score = 0.0
            similarity_scores.append(score)

        return similarity_scores, doc_ids

    def _rank_and_sort_results(self, similarity_scores: List[float],
                               doc_ids: List[str]) -> List[Tuple[str, float]]:
        """
        Rank and sort results by similarity score.

        Args:
            similarity_scores: List of similarity scores
            doc_ids: List of document IDs

        Returns:
            List of (document_id, score) tuples sorted by score, descending
        """
        logger.debug("Ranking and sorting results")
        result_tuples = list(zip(doc_ids, similarity_scores))
        return sorted(result_tuples, key=lambda x: x[1], reverse=True)

    def _filter_to_top_n(self, ranked_results: List[Tuple[str, float]]) -> List[str]:
        """
        Filter to top N results.

        Args:
            ranked_results: List of (document_id, score) tuples, pre-sorted

        Returns:
            List of top N document IDs (threshold-filtered when configured)
        """
        logger.debug(f"Filtering to top {self.configs.retrieval_count} results")

        if self.configs.use_filter:
            filtered_results = [
                doc_id for doc_id, score in ranked_results
                if score >= self.configs.similarity_threshold
            ]
        else:
            filtered_results = [doc_id for doc_id, _ in ranked_results]

        # Return top N results
        return filtered_results[:self.configs.retrieval_count]

    def _retrieve_relevant_documents(self, documents: List[Any], top_doc_ids: List[str]) -> List[Any]:
        """
        Retrieve potentially relevant documents.

        Args:
            documents: List of all documents
            top_doc_ids: List of top document IDs

        Returns:
            List of potentially relevant documents, in top_doc_ids order
        """
        logger.debug("Retrieving potentially relevant documents")

        # Map document ID to document for O(1) lookup.
        doc_map = {doc.get("id"): doc for doc in documents}
        return [doc_map[doc_id] for doc_id in top_doc_ids if doc_id in doc_map]

    def _cosine_similarity(self, vec1: List[float], vec2: List[float]) -> float:
        """Calculate cosine similarity between two vectors (0.0 on error/empty)."""
        if not vec1 or not vec2:
            return 0.0

        try:
            vec1_np = np.array(vec1)
            vec2_np = np.array(vec2)

            norm1 = np.linalg.norm(vec1_np)
            norm2 = np.linalg.norm(vec2_np)

            # FIX: guard against zero-length vectors — previously this divided
            # by zero, yielding NaN and a runtime warning.
            if norm1 == 0.0 or norm2 == 0.0:
                return 0.0

            return float(np.dot(vec1_np, vec2_np) / (norm1 * norm2))
        except Exception as e:
            logger.error(f"Error calculating cosine similarity: {e}")
            return 0.0

    def _dot_product(self, vec1: List[float], vec2: List[float]) -> float:
        """Calculate dot product between two vectors (0.0 on error/empty)."""
        if not vec1 or not vec2:
            return 0.0

        try:
            return float(np.dot(np.array(vec1), np.array(vec2)))
        except Exception as e:
            logger.error(f"Error calculating dot product: {e}")
            return 0.0

    def _euclidean_distance(self, vec1: List[float], vec2: List[float]) -> float:
        """Calculate Euclidean distance between two vectors (inf on error/empty)."""
        if not vec1 or not vec2:
            return float('inf')

        try:
            return float(np.linalg.norm(np.array(vec1) - np.array(vec2)))
        except Exception as e:
            logger.error(f"Error calculating Euclidean distance: {e}")
            return float('inf')
@ -0,0 +1,384 @@
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Dict, List, Any, Optional, Union
|
||||
import logging
|
||||
from enum import Enum
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class BusinessOwnerAssumptions(BaseModel):
    """Assumptions about the business owner used when answering tax questions."""
    # Owner's gross annual income, formatted as a display string (not a number).
    has_annual_gross_income: str = "$70,000"
class BusinessAssumptions(BaseModel):
    """Assumptions about the business itself."""
    # How long the business has been operating, as a display string.
    year_of_operation: str = "second year"
    # Whether the business qualifies for tax incentives.
    qualifies_for_incentives: bool = False
    # Gross annual revenue, formatted as a display string (not a number).
    gross_annual_revenue: str = "$1,000,000"
    # Headcount.
    employees: int = 15
    # Free-text business type including a NAICS code.
    business_type: str = "general commercial activities (NAICS: 4523)"
class TaxesAssumptions(BaseModel):
    """Assumptions about taxes."""
    # Which period the taxes in question were paid for.
    taxes_paid_period: str = "second year of operation"
class OtherAssumptions(BaseModel):
    """Catch-all for assumptions not covered by the typed groups."""
    # Free-text assumptions; default_factory avoids a shared mutable default.
    other_assumptions: List[str] = Field(default_factory=list)
class Assumptions(BaseModel):
    """Collection of all assumption groups; each group is optional."""
    business_owner: Optional[BusinessOwnerAssumptions] = None
    business: Optional[BusinessAssumptions] = None
    taxes: Optional[TaxesAssumptions] = None
    other: Optional[OtherAssumptions] = None
class PromptDecisionTreeNode(BaseModel):
    """Node in a variable's prompt decision tree."""
    # The prompt text for this node.
    prompt: str
    # Names of prompts this node depends on, if any.
    depends_on: Optional[List[str]] = None
    # Mapping of answer -> next prompt name for branching, if any.
    # NOTE(review): only `prompt` is read in this file (see _get_prompt_sequence);
    # the branching fields appear unused so far — confirm intended semantics.
    next_prompts: Optional[Dict[str, str]] = None
class Variable(BaseModel):
    """Variable definition in the codebook."""
    # Human-readable label.
    label: str
    # Unique key; used as the dictionary key in VariableCodebook.variables.
    item_name: str
    # Description of what the variable measures.
    description: str
    # Units the variable is expressed in.
    units: str
    # Optional structured assumptions attached to the variable.
    assumptions: Optional[Assumptions] = None
    # Ordered prompt nodes; _get_prompt_sequence reads node.prompt in order.
    prompt_decision_tree: Optional[List[PromptDecisionTreeNode]] = None
class VariableCodebookConfigs(BaseModel):
    """Configuration for the Variable Codebook."""
    # Path handed to the storage service when loading variables.
    variables_path: str = "variables.json"
    # Whether __init__ eagerly loads variables from storage.
    load_from_file: bool = True
    # NOTE(review): cache settings are declared but not read anywhere in the
    # visible code — confirm they are consumed by the cache service.
    cache_enabled: bool = True
    cache_ttl_seconds: int = 3600
    default_assumptions_enabled: bool = True
class VariableCodebook:
|
||||
"""
|
||||
Variable Codebook system based on mermaid class diagram in README.md
|
||||
Manages variable definitions and their associated assumptions and prompt sequences
|
||||
"""
|
||||
|
||||
def __init__(self, resources: Dict[str, Any], configs: VariableCodebookConfigs):
|
||||
"""
|
||||
Initialize with injected dependencies and configuration
|
||||
|
||||
Args:
|
||||
resources: Dictionary of resources including storage services
|
||||
configs: Configuration for Variable Codebook
|
||||
"""
|
||||
self.resources = resources
|
||||
self.configs = configs
|
||||
|
||||
# Extract needed services from resources
|
||||
self.storage_service = resources.get("storage_service")
|
||||
self.cache_service = resources.get("cache_service")
|
||||
|
||||
# Initialize variables dictionary
|
||||
self.variables: Dict[str, Variable] = {}
|
||||
|
||||
# Load variables if configured
|
||||
if self.configs.load_from_file:
|
||||
self._load_variables()
|
||||
|
||||
logger.info("VariableCodebook initialized with services")
|
||||
|
||||
def control_flow(self, action: str, **kwargs) -> Dict[str, Any]:
|
||||
"""
|
||||
Execute variable codebook operations based on the action
|
||||
|
||||
Args:
|
||||
action: Operation to perform (get_variable, get_prompt_sequence, etc.)
|
||||
**kwargs: Operation-specific parameters
|
||||
|
||||
Returns:
|
||||
Dictionary containing operation results
|
||||
"""
|
||||
logger.info(f"Starting variable codebook operation: {action}")
|
||||
|
||||
if action == "get_variable":
|
||||
return self._get_variable(
|
||||
variable_name=kwargs.get("variable_name", "")
|
||||
)
|
||||
elif action == "get_prompt_sequence":
|
||||
return self._get_prompt_sequence(
|
||||
variable_name=kwargs.get("variable_name", ""),
|
||||
input_data_point=kwargs.get("input_data_point", "")
|
||||
)
|
||||
elif action == "get_assumptions":
|
||||
return self._get_assumptions(
|
||||
variable_name=kwargs.get("variable_name", "")
|
||||
)
|
||||
elif action == "add_variable":
|
||||
return self._add_variable(
|
||||
variable=kwargs.get("variable")
|
||||
)
|
||||
elif action == "update_variable":
|
||||
return self._update_variable(
|
||||
variable_name=kwargs.get("variable_name", ""),
|
||||
variable=kwargs.get("variable")
|
||||
)
|
||||
else:
|
||||
logger.error(f"Unknown action: {action}")
|
||||
raise ValueError(f"Unknown action: {action}")
|
||||
|
||||
def get_prompt_sequence_for_input(self, input_data_point: str) -> List[str]:
|
||||
"""
|
||||
Get prompt sequence for a given input data point
|
||||
|
||||
Args:
|
||||
input_data_point: The query or information request
|
||||
|
||||
Returns:
|
||||
List of prompts in the sequence
|
||||
"""
|
||||
# Extract variable name from input data point
|
||||
variable_name = self._extract_variable_from_input(input_data_point)
|
||||
|
||||
# Get prompt sequence for the variable
|
||||
result = self.control_flow(
|
||||
"get_prompt_sequence",
|
||||
variable_name=variable_name,
|
||||
input_data_point=input_data_point
|
||||
)
|
||||
|
||||
return result.get("prompt_sequence", [])
|
||||
|
||||
def _extract_variable_from_input(self, input_data_point: str) -> str:
|
||||
"""
|
||||
Extract the variable name from the input data point
|
||||
|
||||
This is a simplified implementation that uses keyword matching.
|
||||
In a real system, this could use NLP techniques or more sophisticated parsing.
|
||||
|
||||
Args:
|
||||
input_data_point: The query or information request
|
||||
|
||||
Returns:
|
||||
Variable name
|
||||
"""
|
||||
# Convert to lowercase for case-insensitive matching
|
||||
input_lower = input_data_point.lower()
|
||||
|
||||
# Define keyword mappings to variable names
|
||||
keyword_mappings = {
|
||||
"sales tax": "sales_tax_city",
|
||||
"tax rate": "sales_tax_city",
|
||||
"local tax": "sales_tax_city",
|
||||
"city tax": "sales_tax_city",
|
||||
"municipal tax": "sales_tax_city",
|
||||
"property tax": "property_tax",
|
||||
"income tax": "income_tax"
|
||||
}
|
||||
|
||||
# Find the first matching keyword
|
||||
for keyword, variable in keyword_mappings.items():
|
||||
if keyword in input_lower:
|
||||
return variable
|
||||
|
||||
# Default to a generic variable if no match is found
|
||||
return "generic_tax_information"
|
||||
|
||||
def _get_variable(self, variable_name: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Get a variable from the codebook
|
||||
|
||||
Args:
|
||||
variable_name: Name of the variable
|
||||
|
||||
Returns:
|
||||
Dictionary containing the variable information
|
||||
"""
|
||||
if variable_name in self.variables:
|
||||
return {
|
||||
"success": True,
|
||||
"variable": self.variables[variable_name]
|
||||
}
|
||||
else:
|
||||
logger.warning(f"Variable not found: {variable_name}")
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Variable not found: {variable_name}"
|
||||
}
|
||||
|
||||
def _get_prompt_sequence(self, variable_name: str, input_data_point: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Get the prompt sequence for a variable
|
||||
|
||||
Args:
|
||||
variable_name: Name of the variable
|
||||
input_data_point: The query or information request
|
||||
|
||||
Returns:
|
||||
Dictionary containing the prompt sequence
|
||||
"""
|
||||
# Get the variable
|
||||
variable_result = self._get_variable(variable_name)
|
||||
|
||||
if not variable_result.get("success", False):
|
||||
return variable_result
|
||||
|
||||
variable = variable_result.get("variable")
|
||||
|
||||
# Extract the prompt sequence from the variable
|
||||
if not variable.prompt_decision_tree:
|
||||
logger.warning(f"No prompt decision tree found for variable: {variable_name}")
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"No prompt decision tree found for variable: {variable_name}"
|
||||
}
|
||||
|
||||
# Extract prompts from the decision tree
|
||||
prompts = [node.prompt for node in variable.prompt_decision_tree]
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"prompt_sequence": prompts,
|
||||
"variable": variable
|
||||
}
|
||||
|
||||
def _get_assumptions(self, variable_name: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Get the assumptions for a variable
|
||||
|
||||
Args:
|
||||
variable_name: Name of the variable
|
||||
|
||||
Returns:
|
||||
Dictionary containing the assumptions
|
||||
"""
|
||||
# Get the variable
|
||||
variable_result = self._get_variable(variable_name)
|
||||
|
||||
if not variable_result.get("success", False):
|
||||
return variable_result
|
||||
|
||||
variable = variable_result.get("variable")
|
||||
|
||||
# Extract the assumptions from the variable
|
||||
return {
|
||||
"success": True,
|
||||
"assumptions": variable.assumptions,
|
||||
"variable": variable
|
||||
}
|
||||
|
||||
def _add_variable(self, variable: Variable) -> Dict[str, Any]:
    """
    Add a variable to the codebook.

    Args:
        variable: Variable to add

    Returns:
        Dictionary containing the operation result
    """
    name = variable.item_name

    # Refuse to overwrite an existing entry; use _update_variable for that.
    if name in self.variables:
        logger.warning(f"Variable already exists: {name}")
        return {
            "success": False,
            "error": f"Variable already exists: {name}",
        }

    # Register the new variable under its item name.
    self.variables[name] = variable

    # Persist the addition when a storage backend is configured.
    if self.storage_service:
        self.storage_service.save_variable(variable)

    return {"success": True, "variable": variable}
|
||||
|
||||
def _update_variable(self, variable_name: str, variable: Variable) -> Dict[str, Any]:
    """
    Update a variable in the codebook.

    Args:
        variable_name: Name of the variable to update
        variable: Updated variable

    Returns:
        Dictionary containing the operation result
    """
    # Only existing entries can be updated.
    if variable_name not in self.variables:
        logger.warning(f"Variable not found: {variable_name}")
        return {
            "success": False,
            "error": f"Variable not found: {variable_name}",
        }

    # Replace the stored entry in place.
    # NOTE(review): the entry stays keyed by `variable_name` even when
    # `variable.item_name` differs — confirm renames are not expected here.
    self.variables[variable_name] = variable

    # Persist the change when a storage backend is configured.
    if self.storage_service:
        self.storage_service.save_variable(variable)

    return {"success": True, "variable": variable}
|
||||
|
||||
def _load_variables(self) -> None:
    """Load variables from storage, falling back to defaults on any failure."""
    try:
        # No storage backend: defaults are the only option.
        if not self.storage_service:
            logger.warning("No storage service available, loading default variables")
            self._load_default_variables()
            return

        loaded = self.storage_service.load_variables(self.configs.variables_path)

        # Storage reachable but empty: also fall back to defaults.
        if not loaded:
            logger.warning("No variables found in storage")
            self._load_default_variables()
            return

        # Index the loaded variables by item name.
        by_name = {}
        for var in loaded:
            by_name[var.item_name] = var
        self.variables = by_name
        logger.info(f"Loaded {len(self.variables)} variables from storage")
    except Exception as e:
        # Any failure (storage or default loading in the branches above)
        # degrades to the built-in defaults rather than crashing.
        logger.error(f"Error loading variables: {e}")
        self._load_default_variables()
|
||||
|
||||
def _load_default_variables(self) -> None:
    """Load the built-in sample variables into the codebook."""
    # Respect the config switch that disables built-in assumptions entirely.
    if not self.configs.default_assumptions_enabled:
        logger.info("Default assumptions disabled, skipping default variable loading")
        return

    # Prompts asked, in order, when extracting this variable from a document.
    prompt_texts = [
        "List the name of the tax as given in the document verbatim, as well as its line item.",
        "List the formal definition of the tax verbatim, as well as its line item.",
        "Does this statute apply to all goods or services, or only to specific ones?",
        "What is the exact percentage rate of the tax?",
    ]

    # Sample variable: municipal sales tax, with assumptions and a prompt tree.
    default_variable = Variable(
        label="Sales Tax - City",
        item_name="sales_tax_city",
        description="A tax levied on the sales of all goods and services by the municipal government.",
        units="Double (Percent)",
        assumptions=Assumptions(
            business_owner=BusinessOwnerAssumptions(),
            business=BusinessAssumptions(),
            taxes=TaxesAssumptions(),
            other=OtherAssumptions(
                other_assumptions=["Also assume the business has no special tax exemptions."]
            ),
        ),
        prompt_decision_tree=[PromptDecisionTreeNode(prompt=text) for text in prompt_texts],
    )

    # Register under its item name so later lookups can find it.
    self.variables[default_variable.item_name] = default_variable

    logger.info("Loaded default variables")
|
248
custom_nodes/red_ribbon/socialtoolkit/mermaid_charts/README.md
Normal file
@ -0,0 +1,248 @@
|
||||
|
||||
# High Level Architecture
|
||||
```mermaid
|
||||
flowchart TD
|
||||
A[URLs to Pre-Approved Document Sources] --> B[Domain URLs]
|
||||
B --> C[Document Retrieval from Websites]
|
||||
C --> D[Documents, Document Metadata, & Document Vectors]
|
||||
D --> E[(Document Storage)]
|
||||
E --> F[Documents & Document Vectors]
|
||||
G[Desired Information: Local Sales Tax in Cheyenne, WY] --> H[Input Data Point]
|
||||
F --> I[Top 10 Document Retrieval]
|
||||
H --> I
|
||||
I --> J[Potentially Relevant Documents]
|
||||
J --> K[Relevance Assessment]
|
||||
L{Large Language Model: LLM} --> M[LLM API]
|
||||
M --> K
|
||||
N[Variable Codebook] --> O[Prompt Sequence]
|
||||
O --> K
|
||||
K --> P[Relevant Documents]
|
||||
P --> Q[Prompt Decision Tree]
|
||||
M --> Q
|
||||
Q --> R[Output Data Point]
|
||||
R --> S[Output Data Point: 6%]
|
||||
```
|
||||
|
||||
# Document Retrieval from Websites
|
||||
```mermaid
|
||||
flowchart TD
|
||||
A[Domain URLs] --> B[URL]
|
||||
B --> C{URL Path Generator}
|
||||
C -->|Static Webpages| D[Static Webpage Parser]
|
||||
C -->|Dynamic Webpages| E[Dynamic Webpage Parser]
|
||||
D --> F[Raw Data]
|
||||
E --> G[Raw Data]
|
||||
F --> H[Data Extractor]
|
||||
G --> H
|
||||
H --> I[Raw Strings]
|
||||
I --> J{Document Creator}
|
||||
J -->|Documents| K[Vector Generator]
|
||||
J -->|Documents| L[Metadata Generator]
|
||||
J -->|Documents| N[Document Storage]
|
||||
K --> M[Document Vectors]
|
||||
L --> O[Document Metadata]
|
||||
M --> N
|
||||
O --> N
|
||||
```
|
||||
|
||||
# Document Storage
|
||||
```mermaid
|
||||
erDiagram
|
||||
Sources ||--o{ Documents : contains
|
||||
Documents ||--o{ Versions : has
|
||||
Documents ||--o{ Metadatas : has
|
||||
Documents ||--o{ Contents : has
|
||||
Versions ||--o{ VersionsContents : has
|
||||
Contents ||--o{ Vectors : has
|
||||
|
||||
Sources {
|
||||
string id PK
|
||||
}
|
||||
|
||||
Documents {
|
||||
uuid document_id PK
|
||||
uuid source_id FK
|
||||
varchar url "Length 2300"
|
||||
json scraping_config
|
||||
datetime last_scrape
|
||||
datetime last_successful_scrape
|
||||
uuid current_version_id FK "Updated to latest version"
|
||||
enum status "new, processing, complete, error"
|
||||
tinyint priority "1-5, with 1 being most important. Default: 5"
|
||||
varchar document_type "html, pdf, etc."
|
||||
}
|
||||
|
||||
Versions {
|
||||
uuid version_id PK
|
||||
uuid document_id FK
|
||||
varchar perm_url "Internet Archive, Libgen | Length 2200"
|
||||
boolean current_version
|
||||
string version_number
|
||||
enum status "draft, active, superseded"
|
||||
text change_summary
|
||||
datetime effective_date
|
||||
datetime processed_at
|
||||
}
|
||||
|
||||
Metadatas {
|
||||
uuid metadata_id PK
|
||||
uuid document_id FK
|
||||
json other_metadata
|
||||
varchar local_file_path
|
||||
datetime created_at
|
||||
datetime updated_at
|
||||
}
|
||||
|
||||
Contents {
|
||||
uuid content_id PK
|
||||
uuid version_id FK
|
||||
longtext raw_content
|
||||
longtext processed_content
|
||||
json structure_data
|
||||
varchar location_in_doc "Ex: Page numbers in PDF & Docs. Default to NULL"
|
||||
binary hash "SHA 256 on raw_content, virtual column"
|
||||
}
|
||||
|
||||
VersionsContents {
|
||||
uuid version_id FK
|
||||
uuid content_id FK
|
||||
datetime created_at
|
||||
enum source_type "primary, secondary, tertiary"
|
||||
}
|
||||
|
||||
Vectors {
|
||||
uuid vector_id PK
|
||||
uuid content_id FK
|
||||
embedding vector_embedding
|
||||
enum embedding_type
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
# Top-10 Document Retrieval
|
||||
```mermaid
|
||||
flowchart TD
|
||||
A[Desired Information: Local Sales Tax in Cheyenne, WY] --> B[Input Data Point]
|
||||
B --> C[Encode Query]
|
||||
D[(Document Storage)] --> E[Vector Embeddings]
|
||||
D --> F[Document IDs]
|
||||
C --> G[Encoded Query]
|
||||
E --> H[Similarity Search]
|
||||
F --> H
|
||||
G --> H
|
||||
H --> I[Similarity Scores & Document IDs]
|
||||
I --> J[Rank & Sort Results]
|
||||
J --> K[Sorted Document IDs]
|
||||
K --> L[Filter to Top-10 Results]
|
||||
    L --> M[Potentially Relevant Document IDs]
|
||||
D --> N[Documents]
|
||||
N --> O[Pull Relevant Documents]
|
||||
M --> O
|
||||
O --> P[Potentially Relevant Documents]
|
||||
P --> Q[Relevance Assessment]
|
||||
```
|
||||
|
||||
# Variable Codebook
|
||||
```mermaid
|
||||
classDiagram
|
||||
class Variable {
|
||||
+String label: "Sales Tax - City"
|
||||
+String itemName: "sales_tax_city"
|
||||
+String description: "A tax levied on the sales of all goods and services by the municipal government."
|
||||
+String units: "Double (Percent)"
|
||||
+Map assumptions
|
||||
+List promptDecisionTree
|
||||
}
|
||||
|
||||
class Assumptions {
|
||||
BusinessOwner
|
||||
Business
|
||||
Taxes
|
||||
}
|
||||
|
||||
class BusinessOwnerAssumptions {
|
||||
+String hasAnnualGrossIncome: "$70,000"
|
||||
}
|
||||
|
||||
class BusinessAssumptions {
|
||||
+String yearOfOperation: "second year"
|
||||
+Boolean qualifiesForIncentives: false
|
||||
+Number grossAnnualRevenue: "$1,000,000"
|
||||
+Number employees: 15
|
||||
+String businessType: "general commercial activities (NAICS: 4523)"
|
||||
}
|
||||
|
||||
class TaxesAssumptions {
|
||||
+String taxesPaidPeriod: "second year of operation"
|
||||
}
|
||||
|
||||
class OtherAssumptions {
|
||||
+Strings otherAssumptions: "Also assume..."
|
||||
}
|
||||
|
||||
class PromptDecisionTree {
|
||||
+String prompt1: "List the name of the tax as given in the document verbatim, as well as its line item."
|
||||
+String prompt2: "List the formal definition of the tax verbatim, as well as its line item."
|
||||
+String prompt3: "Does this statute apply to all goods or services, or only to specific ones?"
|
||||
+String prompt4: "..."
|
||||
}
|
||||
|
||||
Variable --> Assumptions
|
||||
Variable --> PromptDecisionTree
|
||||
Assumptions --> BusinessOwnerAssumptions
|
||||
Assumptions --> BusinessAssumptions
|
||||
Assumptions --> TaxesAssumptions
|
||||
Assumptions --> OtherAssumptions
|
||||
```
|
||||
|
||||
# Relevance Assessment
|
||||
```mermaid
|
||||
flowchart TD
|
||||
A[Top 10 Document Retrieval] --> B[Potentially Relevant Documents]
|
||||
C[Variable Codebook] --> D[Variable Definition & Description]
|
||||
E{Large Language Model: LLM} --> F[LLM API]
|
||||
B --> G[Document Relevance Assessment]
|
||||
D --> G
|
||||
G --> H[LLM Hallucination]
|
||||
F --> H
|
||||
G --> I[LLM Assessment & Text Citation]
|
||||
I --> H
|
||||
H --> J[LLM Assessment]
|
||||
J --> K[Relevance Scorer]
|
||||
K --> L[Page Relevance Score]
|
||||
L --> M{Criteria Threshold Check}
|
||||
B --> N[Potentially Relevant Documents]
|
||||
N --> K
|
||||
M --> O[Page Relevance Score < Criteria Threshold] & P[Page Relevance Score >= Criteria Threshold]
|
||||
O --> Q[Discarded Documents Pages Pool]
|
||||
P --> R[Relevant Document Pages Pool]
|
||||
R --> S[Page Numbers]
|
||||
A --> T[Potentially Relevant Documents]
|
||||
T --> U[Cited Page Extractor]
|
||||
S --> U
|
||||
U --> V[Relevant Pages]
|
||||
V --> W[Prompt Decision Tree]
|
||||
```
|
||||
|
||||
# Prompt Decision Tree
|
||||
```mermaid
|
||||
flowchart TD
|
||||
A[Relevant Pages] --> B[Concatenate Pages]
|
||||
B --> C[Concatenated Pages]
|
||||
D{Large Language Model: LLM} --> E[LLM API]
|
||||
F[Variable Codebook] --> G[Desired Data Point Codebook Entry & Prompt Sequence]
|
||||
C --> H[Prompt Decision Tree]
|
||||
E --> H
|
||||
G --> H
|
||||
H --> I[Prompt A: List the name of the tax...]
|
||||
I --> J[Edge]
|
||||
J --> K{Prompt E: List the formal definition...}
|
||||
K --> L[Edge] & M[Edge]
|
||||
L --> N[Prompt C: Does this statute apply to all goods or services...]
|
||||
M --> O{Prompt N:...}
|
||||
N --> P[Final Response]
|
||||
O --> Q[Final Response]
|
||||
P --> R[Output Data Point]
|
||||
Q --> R
|
||||
    S[Errors & Unforeseen Edge Cases] --> T[Human Review]
|
||||
```
|
After Width: | Height: | Size: 30 KiB |
After Width: | Height: | Size: 72 KiB |
After Width: | Height: | Size: 32 KiB |
After Width: | Height: | Size: 71 KiB |
After Width: | Height: | Size: 35 KiB |
After Width: | Height: | Size: 26 KiB |
After Width: | Height: | Size: 51 KiB |
57
custom_nodes/red_ribbon/socialtoolkit/socialtoolkit.py
Normal file
@ -0,0 +1,57 @@
|
||||
"""
|
||||
Social Toolkit - Main entrance file for social media integration tools
|
||||
"""
|
||||
|
||||
from . import SocialToolkitNode
|
||||
from .architecture.document_retrieval_from_websites import DocumentRetrievalFromWebsites
|
||||
from .architecture.document_storage import DocumentStorage
|
||||
from .architecture.llm_service import LLMService
|
||||
from .architecture.high_level_architecture import Socialtoolkit
|
||||
from .architecture.codebook import Codebook
|
||||
|
||||
class SocialToolkitAPI:
    """API for accessing Social Toolkit functionality from other modules."""

    def __init__(self, resources, configs):
        """
        Args:
            resources: Mapping of optional pre-built service instances.
                Missing or falsy entries are replaced with default instances.
            configs: Configuration object shared by all services.
        """
        self.configs = configs
        self.resources = resources

        # Use dict.get() so a missing key falls back to a default instance;
        # the original direct indexing raised KeyError when a key was absent.
        self._document_retrieval_from_websites = (
            self.resources.get("document_retrieval_from_websites")
            or DocumentRetrievalFromWebsites(resources, configs)
        )
        self._document_storage = (
            self.resources.get("document_storage")
            or DocumentStorage(resources, configs)
        )
        self.llm_service = self.resources.get("llm_service") or LLMService(resources, configs)
        self.control_flow = self.resources.get("socialtoolkit") or Socialtoolkit(resources, configs)
        self.codebook = self.resources.get("codebook") or Codebook(resources, configs)

    def document_retrieval_from_websites(self, domain_urls: list[str]) -> tuple['Document', 'Metadata', 'Vectors']:
        """Run document retrieval for the given domain URLs."""
        return self._document_retrieval_from_websites.execute(domain_urls)

    def document_storage(self):
        """Run the document storage pipeline."""
        return self._document_storage.execute()

    # NOTE(review): the original class also defined empty stub methods named
    # llm_service() and control_flow(); the identically named instance
    # attributes assigned in __init__ shadowed them on every instance, so the
    # stubs were unreachable dead code and have been removed. Access the
    # services directly via self.llm_service / self.control_flow.
|
||||
|
||||
|
||||
# Main function that can be called when using this as a script
def main():
    """Main function for Socialtoolkit module."""
    # NOTE(review): `Configs` is not imported anywhere in this module —
    # confirm where it should come from before running this as a script.
    configs = Configs()

    # Build the registry incrementally. The original code referenced
    # `resources` inside its own dict literal, which raises NameError
    # because the name is not bound until the literal is fully evaluated.
    resources = {}
    resources["document_retrieval_from_websites"] = DocumentRetrievalFromWebsites(resources, configs)
    resources["document_storage"] = DocumentStorage(resources, configs)
    resources["llm_service"] = LLMService(resources, configs)
    resources["socialtoolkit"] = Socialtoolkit(resources, configs)
    resources["codebook"] = Codebook(resources, configs)

    print("Social Toolkit module loaded successfully")
    print("Available tools:")
    print("- SocialToolkitNode: Node for ComfyUI integration")
    print("- SocialToolkitAPI: API for programmatic access")


if __name__ == "__main__":
    main()
|
58
custom_nodes/red_ribbon/utils/__init__.py
Normal file
@ -0,0 +1,58 @@
|
||||
"""
|
||||
Utility functions for Red Ribbon custom nodes
|
||||
"""
|
||||
|
||||
def merge_node_mappings(mappings_list):
    """
    Merge multiple node class mappings into one.

    Args:
        mappings_list: List of (NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS) tuples

    Returns:
        tuple: Combined (NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS)
    """
    merged_classes = {}
    merged_display_names = {}

    for classes, display_names in mappings_list:
        merged_classes.update(classes)
        # Display-name mappings are optional; skip None/empty entries.
        if display_names:
            merged_display_names.update(display_names)

    return merged_classes, merged_display_names
|
||||
|
||||
class UtilityNode:
    """Node with utility functions for Red Ribbon."""

    # ComfyUI node metadata: one string output, entry point, menu category.
    RETURN_TYPES = ("STRING",)
    FUNCTION = "log"
    CATEGORY = "Red Ribbon/Utils"

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the input sockets ComfyUI shows for this node."""
        required = {
            "mode": (["debug", "info", "log"], {"default": "info"}),
            "message": ("STRING", {"multiline": True}),
        }
        return {"required": required}

    def log(self, mode, message):
        """Print the message tagged with its upper-cased mode and return it."""
        tagged = f"[{mode.upper()}] {message}"
        print(tagged)
        return (tagged,)
|
||||
|
||||
# Dictionary of nodes to be imported by main.py
NODE_CLASS_MAPPINGS = {"UtilityNode": UtilityNode}

# Add display names for the nodes
NODE_DISPLAY_NAME_MAPPINGS = {"UtilityNode": "Red Ribbon Utility"}


def utils():
    """Return the node class and display-name mappings for this package."""
    return NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS
|
39
custom_nodes/red_ribbon/utils/utils.py
Normal file
@ -0,0 +1,39 @@
|
||||
"""
|
||||
Utils - Main entrance file for Red Ribbon utility functions
|
||||
"""
|
||||
|
||||
from . import UtilityNode, merge_node_mappings
|
||||
|
||||
class UtilsAPI:
    """API for accessing utility functionality from other modules."""

    @staticmethod
    def log_message(mode, message):
        """Log a message with the specified mode.

        Args:
            mode (str): Log mode (debug, info, log)
            message (str): The message to log

        Returns:
            str: Formatted log message
        """
        line = f"[{mode.upper()}] {message}"
        print(line)
        return line

    @staticmethod
    def combine_mappings(mappings_list):
        """Wrapper for the module-level merge_node_mappings function."""
        return merge_node_mappings(mappings_list)
|
||||
|
||||
# Main function that can be called when using this as a script
def main():
    """Print a short summary of what the Red Ribbon utils module provides."""
    summary = [
        "Red Ribbon Utils module loaded successfully",
        "Available tools:",
        "- UtilityNode: Node for ComfyUI integration",
        "- UtilsAPI: API for programmatic access",
        "- merge_node_mappings: Function for merging node mappings",
    ]
    for line in summary:
        print(line)


if __name__ == "__main__":
    main()
|
41
install.sh
Normal file
@ -0,0 +1,41 @@
|
||||
#!/bin/bash
# install.sh — one-shot environment setup: verify that python3 is available,
# create (or reuse) a local "venv" virtual environment, then install core and
# custom-node dependencies from the requirements files when present.

echo "Setting up the environment..."

# Check if Python is installed
if ! command -v python3 &> /dev/null
then
    echo "Python is not installed. Please install Python 3.7 or later and add it to your PATH."
    exit 1
fi

# Check if the virtual environment already exists
if [ -d "venv" ]; then
    echo "Virtual environment already exists. Skipping creation."
else
    # Create a virtual environment if it doesn't exist
    echo "Creating a virtual environment..."
    python3 -m venv venv
fi

# Activate the virtual environment
echo "Activating the virtual environment 'venv'..."
source venv/bin/activate


# Install required packages from requirements.txt
if [[ -f "requirements.txt" ]]; then
    echo "Installing required packages..."
    pip install -r requirements.txt
else
    echo "requirements.txt not found. Skipping package installation."
fi

# Install the extra packages needed by the custom nodes, when listed.
if [[ -f "requirements_custom_nodes.txt" ]]; then
    echo "Installing packages for custom nodes..."
    pip install -r requirements_custom_nodes.txt
else
    echo "requirements_custom_nodes.txt not found. Skipping package installation for custom nodes."
fi

echo "Setup complete!"
|
52
requirements_custom_nodes.txt
Normal file
@ -0,0 +1,52 @@
|
||||
# torch
|
||||
# torchsde
|
||||
# torchvision
|
||||
# torchaudio
|
||||
# numpy>=1.25.0
|
||||
# einops
|
||||
# transformers>=4.28.1
|
||||
# tokenizers>=0.13.3
|
||||
# sentencepiece
|
||||
# safetensors>=0.4.2
|
||||
# aiohttp>=3.11.8
|
||||
# yarl>=1.18.0
|
||||
# pyyaml
|
||||
# Pillow
|
||||
# scipy
|
||||
# tqdm
|
||||
# psutil
|
||||
|
||||
# Custom node dependencies
|
||||
tiktoken
|
||||
duckdb
|
||||
beautifulsoup4
|
||||
html2text
|
||||
huggingface_hub
|
||||
openai
|
||||
pandas
|
||||
pyarrow
|
||||
pydantic
|
||||
pytest
|
||||
pytest-asyncio
|
||||
requests
|
||||
anthropic
|
||||
aiofiles
|
||||
cohere
|
||||
httpx
|
||||
rasterio
|
||||
ComfyUI-EasyNodes
|
||||
mysql-connector-python
|
||||
networkx
|
||||
autoscraper
|
||||
backoff
|
||||
playwright
|
||||
PyPDF2
|
||||
pytest-playwright
|
||||
matplotlib
|
||||
|
||||
# non essential dependencies:
|
||||
kornia>=0.7.1
|
||||
spandrel
|
||||
soundfile
|
||||
av
|
||||
|
22
start.sh
Normal file
@ -0,0 +1,22 @@
|
||||
#!/bin/bash
# start.sh — run the application inside the project's virtual environment:
# activate "venv", execute main.py, then deactivate again.

# Echo to indicate start of the program
echo "STARTING PROGRAM..."

# Activate the virtual environment
source venv/bin/activate

# Echo to indicate the start of the Python script
echo "*** BEGIN PROGRAM ***"

# Run the Python script
python main.py

# Echo to indicate the end of the Python script
echo "*** END PROGRAM ***"

# Deactivate the virtual environment
deactivate

# Echo to indicate program completion
echo "PROGRAM EXECUTION COMPLETE."
|