Skip to content

Commit

Permalink
Fix output gallery for csv format inc. VAE & LoRA (#1591)
Browse files Browse the repository at this point in the history
  • Loading branch information
one-lithe-rune authored Jun 24, 2023
1 parent 5ce6001 commit e3ab844
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 22 deletions.
5 changes: 4 additions & 1 deletion apps/stable_diffusion/src/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -810,8 +810,11 @@ def save_output_img(output_img, img_seed, extra_info={}):

new_entry.update(extra_info)

with open(csv_path, "a", encoding="utf-8") as csv_obj:
csv_mode = "a" if os.path.isfile(csv_path) else "w"
with open(csv_path, csv_mode, encoding="utf-8") as csv_obj:
dictwriter_obj = DictWriter(csv_obj, fieldnames=list(new_entry.keys()))
if csv_mode == "w":
dictwriter_obj.writeheader()
dictwriter_obj.writerow(new_entry)
csv_obj.close()

Expand Down
40 changes: 27 additions & 13 deletions apps/stable_diffusion/web/utils/metadata/csv_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,21 +11,35 @@ def has_csv(image_filename: str) -> bool:
return os.path.exists(csv_path(image_filename))


def parse_csv(image_filename: str):
# We use a reader instead of a DictReader here for images_details.csv files due to the lack of
# headers, and then match up the return list for each row with our guess at which column format
# the file is using.

def matching_filename(image_filename: str, row):
# we assume the final column of the csv has the original filename with full path and match that
# against the image_filename. We then exclude the filename from the output, hence the -1's.
# against the image_filename if we are given a list. Otherwise we assume a dict and and take
# the value of the OUTPUT key
return os.path.basename(image_filename) in (
row[-1] if isinstance(row, list) else row["OUTPUT"]
)


def parse_csv(image_filename: str):
csv_filename = csv_path(image_filename)

matches = [
humanize(row)
for row in csv.reader(open(csv_filename, "r", newline=""))
if row
and humanizable(row)
and os.path.basename(image_filename) in row[-1]
]
with open(csv_filename, "r", newline="") as csv_file:
# We use a reader or DictReader here for images_details.csv depending on whether we think it
# has headers or not. Having headers means less guessing of the format.
has_header = csv.Sniffer().has_header(csv_file.read(2048))
csv_file.seek(0)

reader = (
csv.DictReader(csv_file) if has_header else csv.reader(csv_file)
)

matches = [
# we rely on humanize and humanizable to work out the parsing of the individual .csv rows
humanize(row)
for row in reader
if row
and (has_header or humanizable(row))
and matching_filename(image_filename, row)
]

return matches[0] if matches else {}
32 changes: 24 additions & 8 deletions apps/stable_diffusion/web/utils/metadata/format.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,22 @@
},
}

PARAMS_FORMAT_LONGEST = PARAMS_FORMATS[max(PARAMS_FORMATS.keys())]
PARAMS_FORMAT_CURRENT = {
"VARIANT": "Model",
"VAE": "VAE",
"LORA": "LoRA",
"SCHEDULER": "Sampler",
"PROMPT": "Prompt",
"NEG_PROMPT": "Negative prompt",
"SEED": "Seed",
"CFG_SCALE": "CFG scale",
"PRECISION": "Precision",
"STEPS": "Steps",
"HEIGHT": "Height",
"WIDTH": "Width",
"MAX_LENGTH": "Max Length",
"OUTPUT": "Filename",
}


def compact(metadata: dict) -> dict:
Expand Down Expand Up @@ -97,19 +112,20 @@ def humanize(metadata: dict | list[str], includes_filename=True) -> dict:
)

# For dictionaries we try to use the matching length parameter format if
# available, otherwise we use the longest. Then we swap keys in the
# metadata that match keys in the format for the friendlier name that we
# have set in the format value
# available, otherwise we just use the current format which is assumed to
# have everything currently known about. Then we swap keys in the metadata
# that match keys in the format for the friendlier name that we have set
# in the format value
if isinstance(metadata, dict):
if humanizable(metadata, includes_filename):
format = PARAMS_FORMATS[lookup_key]
else:
format = PARAMS_FORMAT_LONGEST
format = PARAMS_FORMAT_CURRENT

return {
format[key]: value
for (key, value) in metadata.items()
if key in format.keys()
format[key]: metadata[key]
for key in format.keys()
if key in metadata.keys() and metadata[key]
}

raise TypeError("Can only humanize parameter lists or dictionaries")

0 comments on commit e3ab844

Please sign in to comment.