Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

delete and Select All #39

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
219 changes: 140 additions & 79 deletions rope/GUI.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import bisect
import torch
import torchvision
from send2trash import send2trash

torchvision.disable_beta_transforms_warning()
import mimetypes
Expand Down Expand Up @@ -67,6 +68,7 @@ def __init__(self, models):
self.static_widget = {}
self.layer = {}


self.temp_emb = []

self.arcface_dst = np.array([[38.2946, 51.6963], [73.5318, 51.5014], [56.0252, 71.7366], [41.5493, 92.3655], [70.7299, 92.2041]], dtype=np.float32)
Expand Down Expand Up @@ -104,10 +106,75 @@ def __init__(self, models):
"Embedding": []
}
self.source_faces = []
self.bind('<Delete>', self.delete_selected_face)
print("Delete键绑定已设置")
self.bind('<Control-a>', self.select_all_faces) # 添加这一行
print("Delete键和Ctrl+A键绑定已设置")



#####

def select_all_faces(self, event):
print("select_all_faces 方法被调用")
for face in self.source_faces:
face["ButtonState"] = True
face["TKButton"].config(style.media_button_on_3)
self.update_faces_canvas()
print("所有faces已被选中并处理")
self._update_target_face_assignments()

# 更新target faces (如果需要的话)
# self.update_target_faces()

def delete_selected_face(self, event):
print("delete_selected_face 方法被调用")
for i, face in enumerate(self.source_faces):
if face["ButtonState"]:
print(f"找到选中的face: {i}")

# 规范化文件路径
file_path = os.path.normpath(face["file"])
print(f"规范化后的文件路径: {file_path}")

# 检查文件是否存在
if not os.path.exists(file_path):
print(f"文件不存在: {file_path}")
return

# 将文件发送到回收站
try:
send2trash(file_path)
print(f"已将文件发送到回收站: {file_path}")
except Exception as e:
print(f"发送文件到回收站时出错: {str(e)}")
return # 如果操作失败,我们不应继续

# 从列表中移除face
self.source_faces.pop(i)
print(f"已从source_faces列表中移除face {i}")

# 更新画布,而不是重新加载
self.update_faces_canvas()
print("更新了faces画布")
break
else:
print("没有找到选中的face")

def update_faces_canvas(self):
self.source_faces_canvas.delete("all") # 清除画布上的所有内容

for i, face in enumerate(self.source_faces):
button_style = style.media_button_on_3 if face["ButtonState"] else style.media_button_off_3
face["TKButton"] = tk.Button(self.source_faces_canvas, button_style, image=face["Image"], height=90, width=90)
face["TKButton"].bind("<ButtonRelease-1>", lambda event, arg=i: self.select_input_faces(event, arg))
face["TKButton"].bind("<MouseWheel>", self.source_faces_mouse_wheel)

self.source_faces_canvas.create_window((i % 2) * 100, (i // 2) * 100, window=face["TKButton"], anchor='nw')
# print(f"重新绘制了face {i}, 选中状态: {face['ButtonState']}")

self.static_widget['input_faces_scrollbar'].resize_scrollbar(None)

def create_gui(self):

# 1 x 3 Top level grid
Expand Down Expand Up @@ -915,6 +982,7 @@ def select_faces_path(self):
self.load_input_faces()

def load_input_faces(self):
print("load_input_faces 方法被调用")
self.source_faces = []
self.merged_faces_canvas.delete("all")
self.source_faces_canvas.delete("all")
Expand Down Expand Up @@ -945,26 +1013,23 @@ def load_input_faces(self):
self.merged_faces_canvas.configure(scrollregion = self.merged_faces_canvas.bbox("all"))
self.merged_faces_canvas.xview_moveto(0)

except:
pass
print(f"加载了 {len(temp0)} 个合并的embeddings")
except Exception as e:
print(f"加载合并embeddings时出错: {str(e)}")

self.shift_i_len = len(self.source_faces)

# Next Load images
directory = self.json_dict["source faces"]
filenames = [os.path.join(dirpath,f) for (dirpath, dirnames, filenames) in os.walk(directory) for f in filenames]

# torch.cuda.memory._record_memory_history(True, trace_alloc_max_entries=100000, trace_alloc_record_context=True)
i=0
for file in filenames: # Does not include full path
# Find all faces and ad to faces[]
# Guess File type based on extension
for file in filenames:
try:
file_type = mimetypes.guess_type(file)[0][:5]
except:
print('Unrecognized file type:', file)
print('无法识别的文件类型:', file)
else:
# Its an image
if file_type == 'image':
img = cv2.imread(file)

Expand All @@ -989,7 +1054,7 @@ def load_input_faces(self):
try:
kpss = self.models.run_detect(img, max_num=1)[0] # Just one face here
except IndexError:
print('Image cropped too close:', file)
print('图像裁剪过近:', file)
else:
face_emb, cropped_image = self.models.run_recognize(img, kpss)
crop = cv2.cvtColor(cropped_image.cpu().numpy(), cv2.COLOR_BGR2RGB)
Expand All @@ -1011,12 +1076,13 @@ def load_input_faces(self):

self.static_widget['input_faces_scrollbar'].resize_scrollbar(None)
i = i + 1
# print(f"加载了图像: {file}")

else:
print('Bad file', file)

print('无效的文件', file)

torch.cuda.empty_cache()
print(f"总共加载了 {i} 个图像文件")

def find_faces(self):
try:
Expand Down Expand Up @@ -1106,90 +1172,85 @@ def toggle_found_faces_buttons_state(self, button):
self.source_faces[self.target_faces[button]["SourceFaceAssignments"][i]]["TKButton"].config(style.media_button_on_3)

def select_input_faces(self, event, button):

try:
if event.state & 0x4 != 0:
print(f"select_input_faces 被调用: event={event}, button={button}")

if isinstance(event, str):
modifier = event
else:
if event.state & 0x4 != 0: # Ctrl键
modifier = 'ctrl'
elif event.state & 0x1 != 0:
elif event.state & 0x1 != 0: # Shift键
modifier = 'shift'
else:
modifier = 'none'
except:
modifier = event
print(f"修饰符: {modifier}")

if modifier == 'ctrl':
self._ctrl_select_face(button)
elif modifier == 'shift':
self._shift_select_faces(button)
elif modifier == 'none':
self._single_select_face(button)
elif modifier == 'auto':
self._auto_select_faces(button)

self._update_target_face_assignments()

def _ctrl_select_face(self, button):
self.source_faces[button]["ButtonState"] = not self.source_faces[button]["ButtonState"]
self.source_faces[button]["TKButton"].config(
style.media_button_on_3 if self.source_faces[button]["ButtonState"] else style.media_button_off_3
)
print(f"切换了face {button}的ButtonState为: {self.source_faces[button]['ButtonState']}")


def _select_all_faces(self):
for face in self.source_faces:
face["ButtonState"] = True
face["TKButton"].config(style.media_button_on_3)
print("执行了全选")

def _shift_select_faces(self, end_button):
start_button = next((i for i, face in enumerate(self.source_faces) if face["ButtonState"]), 0)
for i in range(min(start_button, end_button), max(start_button, end_button) + 1):
self.source_faces[i]["ButtonState"] = True
self.source_faces[i]["TKButton"].config(style.media_button_on_3)
print(f"激活了从face {start_button}到{end_button}的所有faces")

def _single_select_face(self, button):
for i, face in enumerate(self.source_faces):
if i == button:
face["ButtonState"] = not face["ButtonState"]
else:
face["ButtonState"] = False
face["TKButton"].config(style.media_button_on_3 if face["ButtonState"] else style.media_button_off_3)
print(f"切换了face {button}的ButtonState为: {self.source_faces[button]['ButtonState']}")

def _auto_select_faces(self, button):
# 保持现有的auto行为不变
self.source_faces[button]["ButtonState"] = True
self.source_faces[button]["TKButton"].config(style.media_button_on_3)

# If autoswap isnt on
# Clear all the highlights. Clear all states, excpet if a modifier is being used
# Start by turning off all the highlights on the input faces buttons
if modifier != 'auto':
for face in self.source_faces:
face["TKButton"].config(style.media_button_off_3)

# and also clear the states if not selecting multiples
if modifier == 'none':
face["ButtonState"] = False

# Toggle the state of the selected Input Face
if modifier != 'merge':
self.source_faces[button]["ButtonState"] = not self.source_faces[button]["ButtonState"]

# if shift find any other input faces and activate the state of all faces in between
if modifier == 'shift':
for i in range(button-1, self.shift_i_len-1, -1):
if self.source_faces[i]["ButtonState"]:
for j in range(i, button, 1):
self.source_faces[j]["ButtonState"] = True
break
for i in range(button+1, len(self.source_faces), 1):
if self.source_faces[i]["ButtonState"]:
for j in range(button, i, 1):
self.source_faces[j]["ButtonState"] = True
break

# Highlight all of input faces buttons that have a true state
for face in self.source_faces:
if face["ButtonState"]:
face["TKButton"].config(style.media_button_on_3)

if self.widget['PreviewModeTextSel'].get() == 'FaceLab':
self.add_action("load_target_image", face["file"])
self.image_loaded = True

# Assign all active input faces to the active target face
def _update_target_face_assignments(self):
for tface in self.target_faces:
if tface["ButtonState"]:

# Clear all of the assignments
tface["SourceFaceAssignments"] = []

# Iterate through all Input faces
temp_holder = []
for j in range(len(self.source_faces)):

# If the source face is active
if self.source_faces[j]["ButtonState"]:
for j, face in enumerate(self.source_faces):
if face["ButtonState"]:
tface["SourceFaceAssignments"].append(j)
temp_holder.append(self.source_faces[j]['Embedding'])

# do averaging
temp_holder.append(face['Embedding'])

if temp_holder:
if self.widget['MergeTextSel'].get() == 'Median':
tface['AssignedEmbedding'] = np.median(temp_holder, 0)
elif self.widget['MergeTextSel'].get() == 'Mean':
tface['AssignedEmbedding'] = np.mean(temp_holder, 0)

merge_method = self.widget['MergeTextSel'].get()
tface['AssignedEmbedding'] = np.median(temp_holder, 0) if merge_method == 'Median' else np.mean(temp_holder, 0)
self.temp_emb = tface['AssignedEmbedding']

# for k in range(512):
# self.widget['emb_vec_' + str(k)].set(tface['AssignedEmbedding'][k], False)
print(f"使用{merge_method}合并了embeddings")
break

self.add_action("target_faces", self.target_faces)
self.add_action('get_requested_video_frame', self.video_slider.get())

# latent = torch.from_numpy(self.models.calc_swapper_latent(self.source_faces[button]['Embedding'])).float().to('cuda')
# face['ptrdata'] = self.models.run_swap_stg1(latent)

print("添加了action: target_faces和get_requested_video_frame")


def populate_target_videos(self):
Expand Down