From c05102245d2a0a2161e5f71270ebad26671ead88 Mon Sep 17 00:00:00 2001 From: "TOTHEMOON\\youdo" <1023571809@qq.com> Date: Sun, 25 Aug 2024 19:49:56 +0800 Subject: [PATCH] Added the functionality for Ctrl+A to select all images, as well as the ability to delete selected images using the delete key. --- rope/GUI.py | 219 +++++++++++++++++++++++++++++++++------------------- 1 file changed, 140 insertions(+), 79 deletions(-) diff --git a/rope/GUI.py b/rope/GUI.py index 2437f845..822cc12e 100644 --- a/rope/GUI.py +++ b/rope/GUI.py @@ -10,6 +10,7 @@ import bisect import torch import torchvision +from send2trash import send2trash torchvision.disable_beta_transforms_warning() import mimetypes @@ -67,6 +68,7 @@ def __init__(self, models): self.static_widget = {} self.layer = {} + self.temp_emb = [] self.arcface_dst = np.array([[38.2946, 51.6963], [73.5318, 51.5014], [56.0252, 71.7366], [41.5493, 92.3655], [70.7299, 92.2041]], dtype=np.float32) @@ -104,10 +106,75 @@ def __init__(self, models): "Embedding": [] } self.source_faces = [] + self.bind('', self.delete_selected_face) + print("Delete键绑定已设置") + self.bind('', self.select_all_faces) # 添加这一行 + print("Delete键和Ctrl+A键绑定已设置") ##### + + def select_all_faces(self, event): + print("select_all_faces 方法被调用") + for face in self.source_faces: + face["ButtonState"] = True + face["TKButton"].config(style.media_button_on_3) + self.update_faces_canvas() + print("所有faces已被选中并处理") + self._update_target_face_assignments() + + # 更新target faces (如果需要的话) + # self.update_target_faces() + + def delete_selected_face(self, event): + print("delete_selected_face 方法被调用") + for i, face in enumerate(self.source_faces): + if face["ButtonState"]: + print(f"找到选中的face: {i}") + + # 规范化文件路径 + file_path = os.path.normpath(face["file"]) + print(f"规范化后的文件路径: {file_path}") + + # 检查文件是否存在 + if not os.path.exists(file_path): + print(f"文件不存在: {file_path}") + return + + # 将文件发送到回收站 + try: + send2trash(file_path) + print(f"已将文件发送到回收站: {file_path}") + except Exception as e: + print(f"发送文件到回收站时出错: {str(e)}") + return # 如果操作失败,我们不应继续 + + # 从列表中移除face + self.source_faces.pop(i) + print(f"已从source_faces列表中移除face {i}") + + # 更新画布,而不是重新加载 + self.update_faces_canvas() + print("更新了faces画布") + break + else: + print("没有找到选中的face") + + def update_faces_canvas(self): + self.source_faces_canvas.delete("all") # 清除画布上的所有内容 + + for i, face in enumerate(self.source_faces): + button_style = style.media_button_on_3 if face["ButtonState"] else style.media_button_off_3 + face["TKButton"] = tk.Button(self.source_faces_canvas, button_style, image=face["Image"], height=90, width=90) + face["TKButton"].bind("", lambda event, arg=i: self.select_input_faces(event, arg)) + face["TKButton"].bind("", self.source_faces_mouse_wheel) + + self.source_faces_canvas.create_window((i % 2) * 100, (i // 2) * 100, window=face["TKButton"], anchor='nw') + # print(f"重新绘制了face {i}, 选中状态: {face['ButtonState']}") + + self.static_widget['input_faces_scrollbar'].resize_scrollbar(None) + def create_gui(self): # 1 x 3 Top level grid @@ -915,6 +982,7 @@ def select_faces_path(self): self.load_input_faces() def load_input_faces(self): + print("load_input_faces 方法被调用") self.source_faces = [] self.merged_faces_canvas.delete("all") self.source_faces_canvas.delete("all") @@ -945,8 +1013,9 @@ def load_input_faces(self): self.merged_faces_canvas.configure(scrollregion = self.merged_faces_canvas.bbox("all")) self.merged_faces_canvas.xview_moveto(0) - except: - pass + print(f"加载了 {len(temp0)} 个合并的embeddings") + except Exception as e: + print(f"加载合并embeddings时出错: {str(e)}") self.shift_i_len = len(self.source_faces) @@ -954,17 +1023,13 @@ def load_input_faces(self): directory = self.json_dict["source faces"] filenames = [os.path.join(dirpath,f) for (dirpath, dirnames, filenames) in os.walk(directory) for f in filenames] - # torch.cuda.memory._record_memory_history(True, trace_alloc_max_entries=100000, trace_alloc_record_context=True) i=0 - for file in filenames: # Does not include full path - # Find all faces and ad to faces[] - # Guess File type based on extension + for file in filenames: try: file_type = mimetypes.guess_type(file)[0][:5] except: - print('Unrecognized file type:', file) + print('无法识别的文件类型:', file) else: - # Its an image if file_type == 'image': img = cv2.imread(file) @@ -989,7 +1054,7 @@ def load_input_faces(self): try: kpss = self.models.run_detect(img, max_num=1)[0] # Just one face here except IndexError: - print('Image cropped too close:', file) + print('图像裁剪过近:', file) else: face_emb, cropped_image = self.models.run_recognize(img, kpss) crop = cv2.cvtColor(cropped_image.cpu().numpy(), cv2.COLOR_BGR2RGB) @@ -1011,12 +1076,13 @@ def load_input_faces(self): self.static_widget['input_faces_scrollbar'].resize_scrollbar(None) i = i + 1 + # print(f"加载了图像: {file}") else: - print('Bad file', file) - + print('无效的文件', file) torch.cuda.empty_cache() + print(f"总共加载了 {i} 个图像文件") def find_faces(self): try: @@ -1106,90 +1172,85 @@ def toggle_found_faces_buttons_state(self, button): self.source_faces[self.target_faces[button]["SourceFaceAssignments"][i]]["TKButton"].config(style.media_button_on_3) def select_input_faces(self, event, button): - - try: - if event.state & 0x4 != 0: + print(f"select_input_faces 被调用: event={event}, button={button}") + + if isinstance(event, str): + modifier = event + else: + if event.state & 0x4 != 0: # Ctrl键 modifier = 'ctrl' - elif event.state & 0x1 != 0: + elif event.state & 0x1 != 0: # Shift键 modifier = 'shift' else: modifier = 'none' - except: - modifier = event + print(f"修饰符: {modifier}") + + if modifier == 'ctrl': + self._ctrl_select_face(button) + elif modifier == 'shift': + self._shift_select_faces(button) + elif modifier == 'none': + self._single_select_face(button) + elif modifier == 'auto': + self._auto_select_faces(button) + + self._update_target_face_assignments() + + def _ctrl_select_face(self, button): + self.source_faces[button]["ButtonState"] = not self.source_faces[button]["ButtonState"] + self.source_faces[button]["TKButton"].config( + style.media_button_on_3 if self.source_faces[button]["ButtonState"] else style.media_button_off_3 + ) + print(f"切换了face {button}的ButtonState为: {self.source_faces[button]['ButtonState']}") + + + def _select_all_faces(self): + for face in self.source_faces: + face["ButtonState"] = True + face["TKButton"].config(style.media_button_on_3) + print("执行了全选") + + def _shift_select_faces(self, end_button): + start_button = next((i for i, face in enumerate(self.source_faces) if face["ButtonState"]), 0) + for i in range(min(start_button, end_button), max(start_button, end_button) + 1): + self.source_faces[i]["ButtonState"] = True + self.source_faces[i]["TKButton"].config(style.media_button_on_3) + print(f"激活了从face {start_button}到{end_button}的所有faces") + + def _single_select_face(self, button): + for i, face in enumerate(self.source_faces): + if i == button: + face["ButtonState"] = not face["ButtonState"] + else: + face["ButtonState"] = False + face["TKButton"].config(style.media_button_on_3 if face["ButtonState"] else style.media_button_off_3) + print(f"切换了face {button}的ButtonState为: {self.source_faces[button]['ButtonState']}") + def _auto_select_faces(self, button): + # 保持现有的auto行为不变 + self.source_faces[button]["ButtonState"] = True + self.source_faces[button]["TKButton"].config(style.media_button_on_3) - # If autoswap isnt on - # Clear all the highlights. Clear all states, excpet if a modifier is being used - # Start by turning off all the highlights on the input faces buttons - if modifier != 'auto': - for face in self.source_faces: - face["TKButton"].config(style.media_button_off_3) - - # and also clear the states if not selecting multiples - if modifier == 'none': - face["ButtonState"] = False - - # Toggle the state of the selected Input Face - if modifier != 'merge': - self.source_faces[button]["ButtonState"] = not self.source_faces[button]["ButtonState"] - - # if shift find any other input faces and activate the state of all faces in between - if modifier == 'shift': - for i in range(button-1, self.shift_i_len-1, -1): - if self.source_faces[i]["ButtonState"]: - for j in range(i, button, 1): - self.source_faces[j]["ButtonState"] = True - break - for i in range(button+1, len(self.source_faces), 1): - if self.source_faces[i]["ButtonState"]: - for j in range(button, i, 1): - self.source_faces[j]["ButtonState"] = True - break - - # Highlight all of input faces buttons that have a true state - for face in self.source_faces: - if face["ButtonState"]: - face["TKButton"].config(style.media_button_on_3) - - if self.widget['PreviewModeTextSel'].get() == 'FaceLab': - self.add_action("load_target_image", face["file"]) - self.image_loaded = True - - # Assign all active input faces to the active target face + def _update_target_face_assignments(self): for tface in self.target_faces: if tface["ButtonState"]: - - # Clear all of the assignments tface["SourceFaceAssignments"] = [] - - # Iterate through all Input faces temp_holder = [] - for j in range(len(self.source_faces)): - - # If the source face is active - if self.source_faces[j]["ButtonState"]: + for j, face in enumerate(self.source_faces): + if face["ButtonState"]: tface["SourceFaceAssignments"].append(j) - temp_holder.append(self.source_faces[j]['Embedding']) - - # do averaging + temp_holder.append(face['Embedding']) + if temp_holder: - if self.widget['MergeTextSel'].get() == 'Median': - tface['AssignedEmbedding'] = np.median(temp_holder, 0) - elif self.widget['MergeTextSel'].get() == 'Mean': - tface['AssignedEmbedding'] = np.mean(temp_holder, 0) - + merge_method = self.widget['MergeTextSel'].get() + tface['AssignedEmbedding'] = np.median(temp_holder, 0) if merge_method == 'Median' else np.mean(temp_holder, 0) self.temp_emb = tface['AssignedEmbedding'] - - # for k in range(512): - # self.widget['emb_vec_' + str(k)].set(tface['AssignedEmbedding'][k], False) + print(f"使用{merge_method}合并了embeddings") break - + self.add_action("target_faces", self.target_faces) self.add_action('get_requested_video_frame', self.video_slider.get()) - - # latent = torch.from_numpy(self.models.calc_swapper_latent(self.source_faces[button]['Embedding'])).float().to('cuda') - # face['ptrdata'] = self.models.run_swap_stg1(latent) - + print("添加了action: target_faces和get_requested_video_frame") def populate_target_videos(self):