Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Memory Management ;( the runtime gets crashed after few prompts #109

Open
shiroanon opened this issue Feb 27, 2025 · 0 comments
Open

Memory Management ;( the runtime gets crashed after few prompts #109

shiroanon opened this issue Feb 27, 2025 · 0 comments

Comments

@shiroanon
Copy link

shiroanon commented Feb 27, 2025

import random
import string
import sys
from typing import Sequence, Mapping, Any, Union
import torch
from flask import Flask, request, jsonify ,send_from_directory

from flask_cors import CORS
app = Flask(__name__)
CORS(app)
def random_strings_list(n):
    return [''.join(random.choices(string.ascii_letters + string.digits, k=16)) for _ in range(n)]


import threading
def tu():
  !/content/loophole http 5000
threading.Thread(target=tu , daemon=True).start()

import shiro.utils
# Lets manipulate comfy through args switches
from shiro.cli_args import args, LatentPreviewMethod
args.preview_method = LatentPreviewMethod.Latent2RGB
def f(value, total, preview):
		if preview:
			preview[1].save('/tmp/preview.jpg')
shiro.utils.set_progress_bar_global_hook(f)


def get_value_at_index(obj: Union[Sequence, Mapping], index: int) -> Any:

    try:
        return obj[index]
    except KeyError:
        return obj["result"][index]


def find_path(name: str, path: str = None) -> str:

    # If no path is given, use the current working directory
    if path is None:
        path = os.getcwd()

    # Check if the current directory contains the name
    if name in os.listdir(path):
        path_name = os.path.join(path, name)
        print(f"{name} found: {path_name}")
        return path_name

    # Get the parent directory
    parent_directory = os.path.dirname(path)

    # If the parent directory is the same as the current directory, we've reached the root and stop the search
    if parent_directory == path:
        return None

    # Recursively call the function with the parent directory
    return find_path(name, parent_directory)


def add_shiroui_directory_to_sys_path() -> None:
    shiroui_path = find_path("ShiroUI")
    if shiroui_path is not None and os.path.isdir(shiroui_path):
        sys.path.append(shiroui_path)
        print(f"'{shiroui_path}' added to sys.path")

add_shiroui_directory_to_sys_path()



def import_custom_nodes() -> None:
    import asyncio
    import execution
    from nodes import init_extra_nodes
    import server

    # Creating a new event loop and setting it as the default loop
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)

    # Creating an instance of PromptServer with the loop
    server_instance = server.PromptServer(loop)
    execution.PromptQueue(server_instance)

    # Initializing custom nodes
    init_extra_nodes()


from nodes import (
    CLIPTextEncode,
    CheckpointLoaderSimple,
    VAEDecode,
    NODE_CLASS_MAPPINGS,
    SaveImage,
    LoraLoader,
    EmptyLatentImage,
)


def main(prompt,cf,batch_siz):
    import_custom_nodes()
    global lis
    lis=random_strings_list(1)
    with torch.inference_mode():
        checkpointloadersimple = CheckpointLoaderSimple()
        checkpointloadersimple_1 = checkpointloadersimple.load_checkpoint(
            ckpt_name="kk.safetensors"
        )

        loraloader = LoraLoader()
        loraloader_10 = loraloader.load_lora(
            lora_name="flat.safetensors",
            strength_model=1,
            strength_clip=1,
            model=get_value_at_index(checkpointloadersimple_1, 0),
            clip=get_value_at_index(checkpointloadersimple_1, 1),
        )

        loraloader_11 = loraloader.load_lora(
            lora_name="flat.safetensors",
            strength_model=0,
            strength_clip=0,
            model=get_value_at_index(loraloader_10, 0),
            clip=get_value_at_index(loraloader_10, 1),
        )

        loraloader_12 = loraloader.load_lora(
            lora_name="flat.safetensors",
            strength_model=0,
            strength_clip=0,
            model=get_value_at_index(loraloader_11, 0),
            clip=get_value_at_index(loraloader_11, 1),
        )

        cliptextencode = CLIPTextEncode()
        cliptextencode_3 = cliptextencode.encode(
            text=prompt, clip=get_value_at_index(loraloader_12, 1)
        )

        cliptextencode_4 = cliptextencode.encode(
            text="", clip=get_value_at_index(loraloader_12, 1)
        )

        alignyourstepsscheduler = NODE_CLASS_MAPPINGS["AlignYourStepsScheduler"]()
        alignyourstepsscheduler_5 = alignyourstepsscheduler.get_sigmas(
            model_type="SD1", steps=25, denoise=1
        )

        ksamplerselect = NODE_CLASS_MAPPINGS["KSamplerSelect"]()
        ksamplerselect_6 = ksamplerselect.get_sampler(sampler_name="euler")

        emptylatentimage = EmptyLatentImage()
        emptylatentimage_7 = emptylatentimage.generate(
            width=512, height=512, batch_size=batch_siz
        )

        samplercustom = NODE_CLASS_MAPPINGS["SamplerCustom"]()
        vaedecode = VAEDecode()
        saveimage = SaveImage()

        for q in range(1):
            samplercustom_2 = samplercustom.sample(
                add_noise=True,
                noise_seed=random.randint(1, 2**64),
                cfg=cf,
                model=get_value_at_index(checkpointloadersimple_1, 0),
                positive=get_value_at_index(cliptextencode_3, 0),
                negative=get_value_at_index(cliptextencode_4, 0),
                sampler=get_value_at_index(ksamplerselect_6, 0),
                sigmas=get_value_at_index(alignyourstepsscheduler_5, 0),
                latent_image=get_value_at_index(emptylatentimage_7, 0),
            )

            vaedecode_8 = vaedecode.decode(
                samples=get_value_at_index(samplercustom_2, 0),
                vae=get_value_at_index(checkpointloadersimple_1, 2),
            )

            saveimage_9 = saveimage.save_images(
                filename_prefix=lis[0], images=get_value_at_index(vaedecode_8, 0)
            )



@app.route('/generate', methods=['POST'])
def generate():
    data = request.json
    prompt = data.get('prompt', '')
    cfg = data.get('cfg', 1)
    batch_size = data.get('batch_size', 1)
    seed = data.get('seed', 0)
    global response

    response = {
        "prompt": prompt,
        "cfg": cfg,
        "batch_size": batch_size,
        "seed": seed
    }
    print(response)
    main(prompt,cfg,batch_size)

    query = lis[0]
    directory = "/content/ShiroUI/output"

    if not query or not directory or not os.path.isdir(directory):
        return jsonify({"error": "Invalid query or directory"}), 400

    matched_images = [
        os.path.join("output", f) for f in os.listdir(directory)
        if query in f and f.lower().endswith(('.png', '.jpg', '.jpeg', '.gif', '.bmp', '.webp'))
    ]

    if not matched_images:
        return jsonify({"error": "No matching images found"}), 404

    return jsonify(matched_images)


@app.route('/output/<path:filename>', methods=['GET'])
def get_image(filename):
    directory = "/content/ShiroUI/output"
    return send_from_directory(directory, filename)

if __name__ == '__main__':
    app.run()



Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant