-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
129 lines (101 loc) · 3.92 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import streamlit as st
from PIL import Image
from io import BytesIO
from time import time
from rembg import remove
from img2img import ImageConvert
# Base prompt text that every generation request starts from.
default_prompt = "((best quality)), (detailed), cartoon"

# Keys that init_sessions() pre-registers in st.session_state.
# generate_image() walks the same list when assembling the prompt.
sessions = [
    "uploading",
    "rerun",
    "gender",
    "body",
    "style",
    "hair",
    "bg",
    "strength",
    "control_scale",
]

# Page configuration is issued before any other Streamlit call in this script.
st.set_page_config(layout="wide", page_title="Cartoonize Everything|万物皆可萌")

# Inline CSS that hides the #MainMenu element and the page footer.
hide_menu_style = """
<style>
#MainMenu {visibility: hidden;}
footer {visibility: hidden;}
</style>
"""
st.markdown(hide_menu_style, unsafe_allow_html=True)

st.title("_Cartoonize Everything|万物皆可萌_")
st.sidebar.write("Upload and download")
@st.cache_resource
def get_model():
    """Create the ImageConvert model used for generation.

    The function is wrapped in ``st.cache_resource``, so Streamlit hands back
    the cached instance instead of constructing a new one on every call.
    """
    converter = ImageConvert()
    return converter
def convert_image(img):
    """Encode *img* as PNG and return the raw bytes.

    Args:
        img: Any object exposing ``save(fp, format=...)``; the caller in this
            file passes the image returned by the model.

    Returns:
        The bytes written by ``img.save`` into an in-memory buffer.
    """
    with BytesIO() as png_buffer:
        img.save(png_buffer, format="PNG")
        # Copy the contents out before the buffer is closed.
        return png_buffer.getvalue()
def generate_image(img, model, qr_data):
    """Generate a cartoonized image and expose it for download in the sidebar.

    The prompt is ``default_prompt`` followed by the fragment that
    ``append_prompt`` returns for each key in ``sessions``. When the "bg"
    session value is set to anything other than "default", *img* is first
    passed through rembg's ``remove``.

    Args:
        img: Source image handed to the model (callers pass ``Image.open(...)``).
        model: Object exposing ``generate_image(img, prompt, qr_data=...,
            strength=..., scale=...)``.
        qr_data: Forwarded unchanged to the model as ``qr_data``.

    Returns:
        Whatever ``model.generate_image`` returns; the same object is also
        PNG-encoded for the sidebar download button.
    """
    prompt = default_prompt + "".join(append_prompt(key) for key in sessions)

    background = st.session_state["bg"]
    if background not in (None, "default"):
        img = remove(img)

    result = model.generate_image(
        img,
        prompt,
        qr_data=qr_data,
        strength=st.session_state["strength"],
        scale=st.session_state["control_scale"],
    )

    sidebar = st.sidebar
    sidebar.markdown("\n")
    png_bytes = convert_image(result)
    # Timestamp-derived file name for the downloaded image.
    file_name = f"{int(time()*100000)}.png"
    sidebar.download_button("Download generated image", png_bytes, file_name, "image/png")
    return result
def upload_trigger():
    """Set the "uploading" session flag; registered as the uploader's on_change hook."""
    state = st.session_state
    state["uploading"] = True
def rerun_trigger():
    """Set the "rerun" session flag; registered as the Re-generate button's on_click hook."""
    state = st.session_state
    state["rerun"] = True
def init_sessions():
    """Register every key from ``sessions`` in the session state.

    Keys that already exist keep their current value; missing ones start
    out as ``None``.
    """
    state = st.session_state
    for key in sessions:
        if key in state:
            continue
        state[key] = None
def append_prompt(item, prefix="", suffix=""):
    """Build the prompt fragment for the session-state entry *item*.

    Args:
        item: Key to look up in ``st.session_state``.
        prefix: Text placed directly before the stored value.
        suffix: Text placed directly after the stored value.

    Returns:
        ``", {prefix}{value}{suffix}"`` when the stored value is a string
        other than "default"; otherwise an empty string.
    """
    value = st.session_state[item]
    # Non-string entries (flags, numbers, None) and "default" add nothing.
    if not isinstance(value, str) or value == "default":
        return ""
    return f", {prefix}{value}{suffix}"
# Let the user choose between the two cartoonization modes.
mode = st.radio("Select cartoonization mode", ["photo", "qr code"], horizontal=True)

# The payload text box is only shown in QR mode; otherwise there is no QR data.
qr_content = (
    st.text_input("Input the data you want to embed into the QR code", "https://cartoonme.fun")
    if mode == "qr code"
    else None
)
imageConvertModel = get_model()

# Two side-by-side columns: original on the left, generated on the right.
col1, col2 = st.columns(2)
my_upload = st.sidebar.file_uploader(
    "Upload an image",
    type=["png", "jpg", "jpeg"],
    on_change=upload_trigger,
)
col1.write("Original Image")
col2.write("Generated Image")

init_sessions()

# Until something has been generated, show the local example picture.
if "generated_image" not in st.session_state:
    st.session_state["generated_image"] = Image.open("./example.png")

st.write(
    "**Disclaimer:** we promise all the uploaded and generated images will be completely deleted "
    "as soon as you close the browser."
)
# Use the uploaded file when there is one, otherwise the local ./test.png.
initial_image = my_upload or "./test.png"
col1.image(Image.open(initial_image))

# Generation runs only when one of the trigger callbacks raised its flag;
# both flags are cleared once the new image has been stored.
needs_generation = st.session_state["uploading"] or st.session_state["rerun"]
if needs_generation:
    source_image = Image.open(initial_image)
    st.session_state["generated_image"] = generate_image(source_image, imageConvertModel, qr_content)
    st.session_state["uploading"] = False
    st.session_state["rerun"] = False

col2.image(st.session_state["generated_image"])
# Refinement controls. Each widget's current value is written into the session
# state, where generate_image() reads it on the next triggered generation.
# NOTE(review): the nesting below was reconstructed from a copy of the file
# that had lost its indentation — confirm that both sliders belong to the QR
# branch and that the Re-generate button sits inside the expander.
with st.expander("**_Not what you want? Let's make it better!_**"):
    sub_col1, sub_col2 = st.columns(2)

    gender_options = ("default", "girl", "boy")
    background_options = ("default", "seaside", "city landscape", "blue sky", "flowers")
    style_options = ("default", "vintage", "sci-fi", "realistic")
    hair_options = ("default", "bangs hair", "mohawk", "ponytail", "long hair")

    st.session_state["gender"] = sub_col1.selectbox("gender", gender_options)
    st.session_state["bg"] = sub_col2.selectbox("background", background_options)
    st.session_state["style"] = sub_col1.selectbox("style", style_options)
    st.session_state["hair"] = sub_col2.selectbox("hair", hair_options)

    # The two numeric controls are only offered in QR-code mode.
    if mode == "qr code":
        st.session_state["strength"] = sub_col1.slider("strength", 0.3, 0.9, 0.8)
        st.session_state["control_scale"] = sub_col2.slider("control_scale", 1.2, 10.0, 1.8)

    rerun = st.button("Re-generate", on_click=rerun_trigger)