diff --git a/python_modules/demarshal.py b/python_modules/demarshal.py index da87d44..7e73985 100644 --- a/python_modules/demarshal.py +++ b/python_modules/demarshal.py @@ -101,7 +101,7 @@ def write_parser_helpers(writer): writer.variable_def("uint64_t", "offset") writer.variable_def("parse_func_t", "parse") writer.variable_def("void **", "dest") - writer.variable_def("uint32_t", "nelements") + writer.variable_def("uint64_t", "nelements") writer.end_block(semicolon=True) def write_read_primitive(writer, start, container, name, scope): @@ -186,7 +186,7 @@ def write_validate_switch_member(writer, mprefix, container, switch_member, scop all_as_extra_size = m.is_extra_size() and want_extra_size if not want_mem_size and all_as_extra_size and not scope.variable_defined(item.mem_size()): - scope.variable_def("uint32_t", item.mem_size()) + scope.variable_def("uint64_t", item.mem_size()) sub_want_mem_size = want_mem_size or all_as_extra_size sub_want_extra_size = want_extra_size and not all_as_extra_size @@ -219,7 +219,7 @@ def write_validate_struct_function(writer, struct): scope = writer.function(validate_function, "static intptr_t", "uint8_t *message_start, uint8_t *message_end, uint64_t offset, SPICE_GNUC_UNUSED int minor") scope.variable_def("uint8_t *", "start = message_start + offset") scope.variable_def("SPICE_GNUC_UNUSED uint8_t *", "pos") - scope.variable_def("size_t", "mem_size", "nw_size") + scope.variable_def("uint64_t", "mem_size", "nw_size") num_pointers = struct.get_num_pointers() if num_pointers != 0: scope.variable_def("SPICE_GNUC_UNUSED intptr_t", "ptr_size") @@ -236,7 +236,7 @@ def write_validate_struct_function(writer, struct): writer.newline() writer.comment("Check if struct fits in reported side").newline() - writer.error_check("start + nw_size > message_end") + writer.error_check("nw_size > (uintptr_t) (message_end - start)") writer.statement("return mem_size") @@ -264,26 +264,26 @@ def write_validate_pointer_item(writer, container, item, scope, parent_scope, st # if array, need no function check if target_type.is_array(): - writer.error_check("message_start + %s >= message_end" % v) + writer.error_check("%s >= (uintptr_t) (message_end - message_start)" % v) assert target_type.element_type.is_primitive() array_item = ItemInfo(target_type, "%s__array" % item.prefix, start) - scope.variable_def("uint32_t", array_item.nw_size()) + scope.variable_def("uint64_t", array_item.nw_size()) # don't create a variable that isn't used, fixes -Werror=unused-but-set-variable need_mem_size = want_mem_size or ( want_extra_size and not item.member.has_attr("chunk") and not target_type.is_cstring_length()) if need_mem_size: - scope.variable_def("uint32_t", array_item.mem_size()) + scope.variable_def("uint64_t", array_item.mem_size()) if target_type.is_cstring_length(): writer.assign(array_item.nw_size(), "spice_strnlen((char *)message_start + %s, message_end - (message_start + %s))" % (v, v)) writer.error_check("*(message_start + %s + %s) != 0" % (v, array_item.nw_size())) else: write_validate_array_item(writer, container, array_item, scope, parent_scope, start, True, want_mem_size=need_mem_size, want_extra_size=False) - writer.error_check("message_start + %s + %s > message_end" % (v, array_item.nw_size())) + writer.error_check("%s + %s > (uintptr_t) (message_end - message_start)" % (v, array_item.nw_size())) if want_extra_size: if item.member and item.member.has_attr("chunk"): @@ -321,11 +321,11 @@ def write_validate_array_item(writer, container, item, scope, parent_scope, star nelements = "%s__nbytes" %(item.prefix) real_nelements = "%s__nelements" %(item.prefix) if not parent_scope.variable_defined(real_nelements): - parent_scope.variable_def("uint32_t", real_nelements) + parent_scope.variable_def("uint64_t", real_nelements) else: nelements = "%s__nelements" %(item.prefix) if not parent_scope.variable_defined(nelements): - parent_scope.variable_def("uint32_t", nelements) + parent_scope.variable_def("uint64_t", nelements) if array.is_constant_length(): writer.assign(nelements, array.size) @@ -420,10 +420,10 @@ def write_validate_array_item(writer, container, item, scope, parent_scope, star element_nw_size = element_item.nw_size() element_mem_size = element_item.mem_size() element_extra_size = element_item.extra_size() - scope.variable_def("uint32_t", element_nw_size) - scope.variable_def("uint32_t", element_mem_size) + scope.variable_def("uint64_t", element_nw_size) + scope.variable_def("uint64_t", element_mem_size) if want_extra_size: - scope.variable_def("uint32_t", element_extra_size) + scope.variable_def("uint64_t", element_extra_size) if want_nw_size: writer.assign(nw_size, 0) @@ -556,7 +556,7 @@ def write_validate_container(writer, prefix, container, start, parent_scope, wan sub_want_nw_size = want_nw_size and not m.is_fixed_nw_size() sub_want_mem_size = m.is_extra_size() and want_mem_size sub_want_extra_size = not m.is_extra_size() and m.contains_extra_size() - defs = ["size_t"] + defs = ["uint64_t"] name = prefix_m(prefix, m) if sub_want_nw_size: @@ -697,7 +697,7 @@ def read_array_len(writer, prefix, array, dest, scope, is_ptr): if dest.is_toplevel() and scope.variable_defined(nelements): return nelements # Already there for toplevel, need not recalculate element_type = array.element_type - scope.variable_def("uint32_t", nelements) + scope.variable_def("uint64_t", nelements) if array.is_constant_length(): writer.assign(nelements, array.size) elif array.is_identifier_length(): @@ -1053,9 +1053,9 @@ def write_msg_parser(writer, message): parent_scope.variable_def("SPICE_GNUC_UNUSED uint8_t *", "pos") parent_scope.variable_def("uint8_t *", "start = message_start") parent_scope.variable_def("uint8_t *", "data = NULL") - parent_scope.variable_def("size_t", "nw_size") + parent_scope.variable_def("uint64_t", "nw_size") if want_mem_size: - parent_scope.variable_def("size_t", "mem_size") + parent_scope.variable_def("uint64_t", "mem_size") if not message.has_attr("nocopy"): parent_scope.variable_def("uint8_t *", "in", "end") num_pointers = message.get_num_pointers() @@ -1073,7 +1073,7 @@ def write_msg_parser(writer, message): writer.newline() writer.comment("Check if message fits in reported side").newline() - with writer.block("if (start + nw_size > message_end)"): + with writer.block("if (nw_size > (uintptr_t) (message_end - start))"): writer.statement("return NULL") writer.newline().comment("Validated extents and calculated size").newline() @@ -1084,7 +1084,7 @@ def write_msg_parser(writer, message): writer.assign("*size", "message_end - message_start") writer.assign("*free_message", "nofree") else: - writer.assign("data", "(uint8_t *)malloc(mem_size)") + writer.assign("data", "(uint8_t *)(mem_size > UINT32_MAX ? NULL : malloc(mem_size))") writer.error_check("data == NULL") writer.assign("end", "data + %s" % (msg_sizeof)) writer.assign("in", "start").newline()