diff --git a/ecleankernel/file.py b/ecleankernel/file.py index c552cbd..9008926 100644 --- a/ecleankernel/file.py +++ b/ecleankernel/file.py @@ -111,7 +111,10 @@ def __init__(self, super().__init__(path, KernelFileType.KERNEL) self.internal_version = self.read_internal_version() - def decompress_raw(self, f: typing.IO[bytes]) -> bytes: + def decompress_raw(self, + f: typing.IO[bytes], + size: int = -1 + ) -> bytes: magic_dict = { b'\x1f\x8b\x08': 'gzip', b'\x42\x5a\x68': 'bz2', @@ -122,6 +125,8 @@ def decompress_raw(self, f: typing.IO[bytes]) -> bytes: b'\x89\x4c\x5a\x4f\x00\x0d\x0a\x1a\x0a': 'lzo', } maxlen = max(len(x) for x in magic_dict) + if size > 0 and maxlen > size: + return f.read(size) with autorewind(f): header = f.read(maxlen) for magic, comp in magic_dict.items(): @@ -137,21 +142,15 @@ def decompress_raw(self, f: typing.IO[bytes]) -> bytes: # Technically a redundant import, this is just # to make your IDE happy :) import zstandard - reader = zstandard.ZstdDecompressor().stream_reader(f) - decomp = b'' - while True: - chunk = reader.read(1024 * 1024) - if not chunk: - break - decomp += chunk - return decomp + decompress = zstandard.ZstdDecompressor().decompressobj() + return decompress(f.read(size)) elif comp == 'lzma': # Using .decompress() causes an error because of # no end-of-stream marker - return LZMADecompressor().decompress(f.read()) + return LZMADecompressor().decompress(f.read(size)) else: - return getattr(mod, 'decompress')(f.read()) - return f.read() + return getattr(mod, 'decompress')(f.read(size)) + return f.read(size) def read_internal_version(self) -> str: """Read version from the kernel file""" @@ -199,11 +198,12 @@ def read_version_from_bzimage(self, def read_version_from_raw(self, f: typing.IO[bytes], + size: int = -1, ) -> typing.Optional[bytes]: """Read version from raw kernel image""" # check if it's compressed first - b = self.decompress_raw(f) + b = self.decompress_raw(f, size) # unlike with bzImage, the raw kernel binary has no header # that includes the version, so we parse the version message # that appears on boot @@ -226,6 +226,14 @@ def read_version_from_efi(self, buf = f.read(0x40) if len(buf) != 0x40 or buf[:2] != b"MZ": return None + + # handle EFI zboot image + # see kernel source code drivers/firmware/efi/libstub/zboot-header.S + if buf[4:8] == b"zimg": + offset, size = struct.unpack_from(" None: + """Write an EFI zboot kernel image at `path`, with `version_line`""" + # generate a compressed image as our payload + write_compress(path, version_line) + b = path.read_bytes() + + with open(path, "wb") as f: + f.write(b"MZ\0\0zimg") + f.write(struct.pack(" None: self.td = tempfile.TemporaryDirectory() @@ -147,6 +164,13 @@ def test_read_internal_version_efistub_uname_nowhitespace(self) -> None: KernelImage(path).read_internal_version(), "1.2.3") + def test_read_internal_version_efi_zboot(self) -> None: + path = Path(self.td.name) / "vmlinuz" + write_efi_zboot(path, b"Linux version 1.2.3 built on test") + self.assertEqual( + KernelImage(path).read_internal_version(), + "1.2.3") + def test_very_short(self) -> None: path = Path(self.td.name) / 'vmlinuz' with open(path, 'wb') as f: