summaryrefslogtreecommitdiff
path: root/lib/compression/tests/scripts/generate-windows-test-vectors.py
blob: b5da5b830bcae97bb179db242cb2e9bc1f982cea (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
# Generate test vectors for Windows LZ77 Huffman compression.
#
# Copyright (c) 2022 Catalyst IT
#
# GPLv3+.
#
# This uses the Python ctypes module to access the lower level RTL
# compression functions.

import sys
import argparse
from ctypes import create_string_buffer, byref, windll
from ctypes.wintypes import USHORT, ULONG, LONG, PULONG, LPVOID, CHAR
# The Rtl* compression routines return an NTSTATUS; model it as a LONG.
NTSTATUS = LONG


# Compression methods accepted on the command line, mapped to the
# COMPRESSION_FORMAT_* number handed to the Rtl* functions.  Every method
# can be named either symbolically or by its number written as a string.
_FORMAT_NUMBERS = (
    ('LZNT1', 2),
    ('XPRESS_PLAIN', 3),
    ('XPRESS_HUFF', 4),
)
METHODS = dict(_FORMAT_NUMBERS)
METHODS.update((str(number), number) for _, number in _FORMAT_NUMBERS)


class RtlError(Exception):
    """Raised when an ntdll Rtl* call reports a failing NTSTATUS."""


def ntstatus_check(status, f, args):
    """ctypes ``errcheck`` hook that turns failing NTSTATUS codes into RtlError.

    :param status: the NTSTATUS returned by the wrapped function.  It is a
                   signed LONG, so failure codes arrive as negative ints.
    :param f: the ctypes function that was called (unused).
    :param args: the arguments it was called with (unused).
    :return: *status* masked to an unsigned 32-bit value, when it is
             STATUS_SUCCESS (0) or STATUS_BUFFER_ALL_ZEROS (0x117).
    :raises RtlError: for any other status, with a short description
                      when the code is a known one.
    """
    # Mask the signed value so failure codes compare as 0xCxxxxxxx.
    status &= 0xffffffff
    # 0x117 is STATUS_BUFFER_ALL_ZEROS, which is a success code.
    if status in (0, 0x117):
        return status
    msg = {
        0xC000000D: "invalid parameter",
        0xC0000023: "buffer too small",
        0xC00000BB: "not supported",
        0xC0000242: "bad compression data",
        0xC000025F: "unsupported compression format",
    }.get(status)

    text = f'NTSTATUS: {status:08X}'
    if msg:
        # Only append a description when we have one (avoids a dangling
        # trailing space for unknown codes).
        text = f'{text} {msg}'
    raise RtlError(text)


def wrap(f, result, *args):
    """Give the ctypes foreign function *f* a prototype and an error check.

    *result* becomes its ``restype`` and *args* its ``argtypes``;
    ntstatus_check is installed as ``errcheck`` so a failing status
    raises RtlError.  *f* itself is returned, so a binding can be made in
    a single expression.
    """
    f.argtypes = args
    f.restype = result
    f.errcheck = ntstatus_check
    return f


# Bindings for the ntdll.dll compression routines.  Each argument is
# annotated with the role it plays at the call sites in compress() and
# decompress().
_ntdll = windll.ntdll

CompressBuffer = wrap(_ntdll.RtlCompressBuffer,
                      NTSTATUS,
                      USHORT,   # format | engine flags
                      LPVOID,   # uncompressed input
                      ULONG,    # input length
                      LPVOID,   # output buffer
                      ULONG,    # output buffer length
                      ULONG,    # chunk size
                      PULONG,   # out: compressed length
                      LPVOID)   # workspace


GetCompressionWorkSpaceSize = wrap(_ntdll.RtlGetCompressionWorkSpaceSize,
                                   NTSTATUS,
                                   USHORT,   # format | engine flags
                                   PULONG,   # out: workspace size
                                   PULONG)   # out: fragment workspace size


DecompressBufferEx = wrap(_ntdll.RtlDecompressBufferEx,
                          NTSTATUS,
                          USHORT,   # format
                          LPVOID,   # output buffer
                          ULONG,    # output buffer length
                          LPVOID,   # compressed input
                          ULONG,    # input length
                          PULONG,   # out: decompressed length
                          LPVOID)   # workspace


def compress(data, format, effort=0):
    """Compress *data* with the Windows RtlCompressBuffer API.

    :param data: the bytes to compress (bytes, bytearray or other buffer).
    :param format: a COMPRESSION_FORMAT_* number (see METHODS).
    :param effort: the COMPRESSION_ENGINE_* flag OR-ed into *format*.
                   0 selects the standard engine; 0x0100 is
                   COMPRESSION_ENGINE_MAXIMUM.  For compatibility, 1 (or
                   True) is accepted as shorthand for the maximum engine.
    :return: a bytearray holding the compressed output.
    :raises ValueError: if *effort* has bits in the low (format) byte.
    :raises RtlError: if an Rtl* call returns a failing NTSTATUS.
    """
    # The engine flags live in the high byte of the USHORT and the format in
    # the low byte.  OR-ing a bare 1 in would silently change the format
    # (LZNT1 2 -> XPRESS 3, XPRESS_HUFF 4 -> invalid 5) instead of asking for
    # more effort, so translate it to COMPRESSION_ENGINE_MAXIMUM.
    if effort == 1:
        effort = 0x0100
    if effort & 0xff:
        raise ValueError(f'effort 0x{effort:x} would overwrite the '
                         'compression format bits')

    flags = USHORT(format | effort)
    workspace_size = ULONG(0)
    fragment_size = ULONG(0)
    comp_len = ULONG(0)
    GetCompressionWorkSpaceSize(flags,
                                byref(workspace_size),
                                byref(fragment_size))
    workspace = create_string_buffer(workspace_size.value)

    # ctypes from_buffer() needs writable memory, so copy immutable input.
    src = data if isinstance(data, bytearray) else bytearray(data)
    # Head-room for incompressible input; if it is still too small the call
    # reports an error status, which ntstatus_check raises.
    output_len = len(src) * 9 // 8 + 260
    output_buf = bytearray(output_len)

    # Size each ctypes array to its real buffer: a fixed (CHAR * 1) view
    # makes from_buffer() reject empty input.
    CompressBuffer(flags,
                   (CHAR * len(src)).from_buffer(src), len(src),
                   (CHAR * output_len).from_buffer(output_buf), output_len,
                   4096,  # UncompressedChunkSize
                   byref(comp_len),
                   workspace)
    return output_buf[:comp_len.value]


def decompress(data, format, target_size=None):
    """Decompress *data* with the Windows RtlDecompressBufferEx API.

    :param data: the compressed bytes (bytes, bytearray or other buffer).
    :param format: a COMPRESSION_FORMAT_* number (see METHODS).
    :param target_size: size of the output buffer to decompress into;
                        main() requires it for XPRESS_HUFF.  If None, a
                        buffer ten times the size of *data* is used.
    :return: a bytearray holding the decompressed output, trimmed to the
             length RtlDecompressBufferEx reports.
    :raises ValueError: if *target_size* is negative.
    :raises RtlError: if an Rtl* call returns a failing NTSTATUS.
    """
    if target_size is not None and target_size < 0:
        raise ValueError(f'target_size must not be negative: {target_size}')

    flags = USHORT(format)
    workspace_size = ULONG(0)
    fragment_size = ULONG(0)
    decomp_len = ULONG(0)
    GetCompressionWorkSpaceSize(flags,
                                byref(workspace_size),
                                byref(fragment_size))
    workspace = create_string_buffer(workspace_size.value)

    # ctypes from_buffer() needs writable memory, so copy immutable input.
    src = data if isinstance(data, bytearray) else bytearray(data)
    if target_size is None:
        # No size given: guess at a generous expansion ratio.
        output_len = len(src) * 10
    else:
        output_len = target_size
    output_buf = bytearray(output_len)

    # Pass the same prepared flags as GetCompressionWorkSpaceSize, and size
    # each ctypes array to its real buffer so empty buffers are accepted.
    DecompressBufferEx(flags,
                       (CHAR * output_len).from_buffer(output_buf), output_len,
                       (CHAR * len(src)).from_buffer(src), len(src),
                       byref(decomp_len),
                       workspace)
    return output_buf[:decomp_len.value]


def main():
    """Command-line entry point: compress or decompress one file via ntdll.

    Reads the whole of --input, passes it through compress() or
    decompress() with the chosen method, and writes the result to
    --output.  Exits with status 1 if XPRESS_HUFF decompression is
    requested without --decompressed-size.
    """
    # Windows 7 reports itself as 6.1 and Windows 8 as 6.2, so the old
    # "major < 7" test fired on every release before Windows 10.
    # RtlDecompressBufferEx is documented as Windows 8 (6.2) and later.
    if sys.getwindowsversion()[:2] < (6, 2):
        print("this probably won't work on your very old version of Windows\n"
              "but we'll try anyway!", file=sys.stderr)

    parser = argparse.ArgumentParser()
    parser.add_argument('-d', '--decompress', action='store_true',
                        help='decompress instead of compress')
    parser.add_argument('-m', '--method', default='XPRESS_HUFF',
                        choices=list(METHODS.keys()),
                        help='use this compression method')
    parser.add_argument('-e', '--extra-effort', action='store_true',
                        help='use extra effort to compress')

    parser.add_argument('-s', '--decompressed-size', type=int,
                        help=('decompress to this size '
                              '(required for XPRESS_HUFF)'))

    # Both files are mandatory: open(None) would otherwise crash.
    parser.add_argument('-o', '--output', required=True,
                        help='write to this file')
    parser.add_argument('-i', '--input', required=True,
                        help='read data from this file')

    args = parser.parse_args()

    method = METHODS[args.method]

    if args.decompress and args.decompressed_size is None and method == 4:
        print("a size is required for XPRESS_HUFF decompression",
              file=sys.stderr)
        sys.exit(1)

    with open(args.input, 'rb') as f:
        data = bytearray(f.read())

    if args.decompress:
        output = decompress(data, method, args.decompressed_size)
    else:
        # COMPRESSION_ENGINE_MAXIMUM belongs in the high byte; passing a bare
        # 1 would be OR-ed into the format number and corrupt it.
        effort = 0x0100 if args.extra_effort else 0
        output = compress(data, method, effort)

    with open(args.output, 'wb') as f:
        f.write(output)


# Only run when executed as a script, so importing this module (e.g. to
# reuse compress()/decompress()) does not parse sys.argv.
if __name__ == '__main__':
    main()