Coverage for dataclasses_struct/dataclass.py: 98%
229 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-15 10:33 +1200
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-15 10:33 +1200
import dataclasses
import sys
from collections.abc import Generator, Iterator
from struct import Struct
from types import GenericAlias
from typing import (
    Annotated,
    Any,
    Callable,
    ClassVar,
    Generic,
    Literal,
    Protocol,
    TypedDict,
    TypeVar,
    Union,
    get_args,
    get_origin,
    get_type_hints,
    overload,
)

from ._typing import TypeGuard, Unpack, dataclass_transform
from .field import Field, builtin_fields
from .types import PadAfter, PadBefore
27def _separate_padding_from_annotation_args(args) -> tuple[int, int, object]:
28 pad_before = pad_after = 0
29 extra_arg = None # should be Field or integer for bytes/list types
30 for arg in args:
31 if isinstance(arg, PadBefore):
32 pad_before += arg.size
33 elif isinstance(arg, PadAfter):
34 pad_after += arg.size
35 elif extra_arg is not None:
36 raise TypeError(f"too many annotations: {arg}")
37 else:
38 extra_arg = arg
40 return pad_before, pad_after, extra_arg
43def _format_str_with_padding(fmt: str, pad_before: int, pad_after: int) -> str:
44 return "".join(
45 (
46 (f"{pad_before}x" if pad_before else ""),
47 fmt,
48 (f"{pad_after}x" if pad_after else ""),
49 )
50 )
53T = TypeVar("T")
56_SIZE_BYTEORDER_MODE_CHAR: dict[tuple[str, str], str] = {
57 ("native", "native"): "@",
58 ("std", "native"): "=",
59 ("std", "little"): "<",
60 ("std", "big"): ">",
61 ("std", "network"): "!",
62}
63_MODE_CHAR_SIZE_BYTEORDER: dict[str, tuple[str, str]] = {
64 v: k for k, v in _SIZE_BYTEORDER_MODE_CHAR.items()
65}
class _DataclassStructInternal(Generic[T]):
    """
    Packing/unpacking machinery attached to a class as ``__dataclass_struct__``.

    Holds the compiled ``struct.Struct`` for the class, the ordered attribute
    names, and the matching ``(Field, type)`` pairs. ``pack`` flattens an
    instance (recursing into nested structs and lists) into struct values;
    ``unpack`` consumes unpacked values in the same order to rebuild one.
    """

    struct: Struct
    cls: type[T]
    _fieldnames: list[str]
    _fields: list[tuple[Field[Any], type]]

    @property
    def format(self) -> str:
        """Full struct format string, including the leading mode character."""
        return self.struct.format

    @property
    def size(self) -> int:
        """Size in bytes of the packed representation."""
        return self.struct.size

    @property
    def mode(self) -> str:
        """Leading mode character of the format string."""
        return self.format[0]

    def __init__(
        self,
        fmt: str,
        cls: type,
        fieldnames: list[str],
        fields: list[tuple[Field[Any], type]],
    ):
        # ``fieldnames`` and ``fields`` are parallel lists: entry i of each
        # describes the same attribute of ``cls``.
        self.struct = Struct(fmt)
        self.cls = cls
        self._fieldnames = fieldnames
        self._fields = fields

    def _flattened_attrs(self, outer_self: T) -> list[Any]:
        """
        Returns a list of all attributes of `outer_self`, including those of
        any nested structs.
        """
        attrs: list[Any] = []
        for fieldname in self._fieldnames:
            attr = getattr(outer_self, fieldname)
            self._flatten_attr(attrs, attr)
        return attrs

    @staticmethod
    def _flatten_attr(attrs: list[Any], attr: object) -> None:
        """Append ``attr`` to ``attrs``, expanding nested structs and lists."""
        if is_dataclass_struct(attr):
            attrs.extend(attr.__dataclass_struct__._flattened_attrs(attr))
        elif isinstance(attr, list):
            for sub_attr in attr:
                _DataclassStructInternal._flatten_attr(attrs, sub_attr)
        else:
            attrs.append(attr)

    def pack(self, obj: T) -> bytes:
        """Pack ``obj`` (an instance of ``self.cls``) into bytes."""
        return self.struct.pack(*self._flattened_attrs(obj))

    def _arg_generator(self, args: Iterator) -> Generator:
        """Yield one reconstructed value per field, in field order."""
        for field, fieldtype in self._fields:
            yield from _DataclassStructInternal._generate_args_recursively(
                args, field, fieldtype
            )

    @staticmethod
    def _generate_args_recursively(
        args: Iterator,
        field: Field[Any],
        field_type: type,
    ) -> Generator:
        """
        Yield exactly one value for ``field``, consuming as many raw unpacked
        values from ``args`` as the field occupies.
        """
        if is_dataclass_struct(field_type):
            yield field_type.__dataclass_struct__._init_from_args(args)
        elif isinstance(field, _FixedLengthArrayField):
            items: list = []
            for _ in range(field.n):
                items.extend(
                    _DataclassStructInternal._generate_args_recursively(
                        args, field.item_field, field.item_type
                    )
                )
            yield items
        else:
            yield field_type(next(args))

    def _init_from_args(self, args: Iterator) -> T:
        """
        Returns an instance of self.cls, consuming args.

        Values are passed to the constructor by field name rather than by
        position, so classes decorated with ``kw_only=True`` (whose generated
        ``__init__`` only accepts keyword arguments) can be rebuilt too.
        """
        # zip() pulls a field name before each value and stops once the names
        # run out, so no extra values are consumed from ``args``.
        values = self._arg_generator(args)
        return self.cls(**dict(zip(self._fieldnames, values)))

    def unpack(self, data: bytes) -> T:
        """Build an instance of ``self.cls`` from its packed bytes."""
        return self._init_from_args(iter(self.struct.unpack(data)))
class DataclassStructProtocol(Protocol):
    """Structural type of a class decorated with ``dataclass_struct``."""

    __dataclass_struct__: _DataclassStructInternal

    @classmethod
    def from_packed(cls: type[T], data: bytes) -> T:
        """Construct an instance from its packed bytes."""

    def pack(self) -> bytes:
        """Return the packed bytes for this instance."""
@overload
def is_dataclass_struct(
    obj: type,
) -> TypeGuard[type[DataclassStructProtocol]]: ...


@overload
def is_dataclass_struct(obj: object) -> TypeGuard[DataclassStructProtocol]: ...


def is_dataclass_struct(
    obj: Union[type, object],
) -> Union[
    TypeGuard[DataclassStructProtocol],
    TypeGuard[type[DataclassStructProtocol]],
]:
    """
    Report whether ``obj`` is a ``dataclass_struct``-decorated class, or an
    instance of one.

    ``obj`` must be a dataclass (class or instance) whose
    ``__dataclass_struct__`` attribute is a ``_DataclassStructInternal``.
    """
    if not dataclasses.is_dataclass(obj):
        return False
    internal = getattr(obj, "__dataclass_struct__", None)
    return isinstance(internal, _DataclassStructInternal)
def get_struct_size(cls_or_obj) -> int:
    """
    Return the packed size, in bytes, of a dataclass_struct.

    ``cls_or_obj`` may be a decorated class or an instance of one; anything
    else raises TypeError.
    """
    if is_dataclass_struct(cls_or_obj):
        return cls_or_obj.__dataclass_struct__.size
    raise TypeError(f"{cls_or_obj} is not a dataclass_struct")
class _BytesField(Field[bytes]):
    """Fixed-length ``bytes`` field, packed with the ``<n>s`` struct code."""

    field_type = bytes

    def __init__(self, n: object):
        if isinstance(n, int) and n >= 1:
            self.n = n
            return
        raise ValueError("bytes length must be positive non-zero int")

    def format(self) -> str:
        return f"{self.n}s"

    def validate_default(self, val: bytes) -> None:
        # Values up to ``n`` bytes are accepted; only longer ones are rejected.
        if len(val) <= self.n:
            return
        raise ValueError(f"bytes cannot be longer than {self.n} bytes")

    def __repr__(self) -> str:
        base = super().__repr__()
        return f"{base}({self.n})"
class _NestedField(Field):
    """Field that embeds another dataclass_struct class inline."""

    field_type: type[DataclassStructProtocol]

    def __init__(self, cls: type[DataclassStructProtocol]):
        self.field_type = cls

    def format(self) -> str:
        # The nested class's format starts with its own mode character; strip
        # it so the nested codes can be spliced into the container's format.
        nested_format = self.field_type.__dataclass_struct__.format
        return nested_format[1:]
class _FixedLengthArrayField(Field[list]):
    """
    Fixed-length list field: ``n`` consecutive copies of one item field.

    The item annotation is resolved with the container's mode, and any
    per-item padding is repeated for every element.
    """

    field_type = list

    def __init__(self, item_type_annotation: Any, mode: str, n: object):
        if not (isinstance(n, int) and n >= 1):
            raise ValueError(
                "fixed-length array length must be positive non-zero int"
            )

        (
            self.item_field,
            self.item_type,
            self.pad_before,
            self.pad_after,
        ) = _resolve_field(item_type_annotation, mode)
        self.n = n
        # Size-mode support (native/standard) is inherited from the item.
        self.is_native = self.item_field.is_native
        self.is_std = self.item_field.is_std

    def format(self) -> str:
        single_item = _format_str_with_padding(
            self.item_field.format(), self.pad_before, self.pad_after
        )
        return single_item * self.n

    def __repr__(self) -> str:
        base = super().__repr__()
        return f"{base}({self.item_field!r}, {self.n})"
263def _validate_modes_match(mode: str, nested_mode: str) -> None:
264 if mode != nested_mode:
265 size, byteorder = _MODE_CHAR_SIZE_BYTEORDER[nested_mode]
266 exp_size, exp_byteorder = _MODE_CHAR_SIZE_BYTEORDER[mode]
267 msg = (
268 "byteorder and size of nested dataclass-struct does not "
269 f"match that of container (expected '{exp_size}' size and "
270 f"'{exp_byteorder}' byteorder, got '{size}' size and "
271 f"'{byteorder}' byteorder)"
272 )
273 raise TypeError(msg)
def _resolve_field(
    annotation: Any,
    mode: str,
) -> tuple[Field[Any], type, int, int]:
    """
    Returns 4-tuple of:
    * field
    * type
    * number of padding bytes before
    * number of padding bytes after

    Valid type annotations are:

    1. <bool | int | float | bytes> | Annotated[<bool | int | float | bytes>, <padding>]

        Supported builtin types.

    2. Annotated[<bool | int | float | bytes>, Field(...), <padding>]

        (These are the types defined in dataclasses_struct.types e.g. U32, F32).

    3. <dataclasses_struct class> | Annotated[<dataclasses_struct class>, <padding>]

        Must have the same size and byteorder as the container.

    4. Annotated[bytes, <n>, <padding>]

        Where <n> is >0.

    5. Annotated[list[<type>], <n>, <padding>]

        Where <n> is >0 and <type> is one of the above.

    <padding> is an optional mixture of PadBefore and PadAfter annotations,
    which may be repeated. E.g.

        Annotated[int, PadBefore(5), PadAfter(2), PadBefore(3)]

    Raises TypeError for an unsupported or malformed annotation, including a
    nested struct whose size/byteorder differs from ``mode``. Invalid lengths
    for bytes/list types raise ValueError from the field constructors.
    """  # noqa: E501

    if get_origin(annotation) == Annotated:
        type_, *args = get_args(annotation)
        pad_before, pad_after, annotation_arg = (
            _separate_padding_from_annotation_args(args)
        )
    else:
        pad_before = pad_after = 0
        type_ = annotation
        annotation_arg = None

    field: Field[Any]
    if annotation_arg is None:
        if get_origin(type_) is list:
            msg = (
                "list types must be marked as a fixed-length using "
                "Annotated, ex: Annotated[list[int], 5]"
            )
            raise TypeError(msg)

        # Must be either a nested type or one of the supported builtins
        if is_dataclass_struct(type_):
            _validate_modes_match(mode, type_.__dataclass_struct__.mode)
            field = _NestedField(type_)
        else:
            opt_field = builtin_fields.get(type_)
            if opt_field is None:
                raise TypeError(f"type not supported: {annotation}")
            field = opt_field
    elif isinstance(annotation_arg, Field):
        field = annotation_arg
    elif get_origin(type_) is list:
        item_annotations = get_args(type_)
        # An explicit check rather than ``assert``: asserts are stripped under
        # ``python -O``, and ``list[int, str]`` is accepted at runtime.
        if len(item_annotations) != 1:
            raise TypeError(
                f"list type must have exactly one item type: {annotation!r}"
            )
        field = _FixedLengthArrayField(
            item_annotations[0], mode, annotation_arg
        )
    elif isinstance(type_, type) and issubclass(type_, bytes):
        # The isinstance guard keeps non-class annotations (e.g. a Union)
        # from crashing issubclass(); they fall through to the error below.
        field = _BytesField(annotation_arg)
    else:
        raise TypeError(f"invalid field annotation: {annotation!r}")

    return field, type_, pad_before, pad_after
def _validate_and_parse_field(
    cls: type,
    name: str,
    field_type: type,
    is_native: bool,
    validate_defaults: bool,
    mode: str,
) -> tuple[str, Field, type]:
    """
    Resolve one annotated attribute of ``cls`` into a struct field.

    Args:
        cls: the class being decorated; any default value is read from it.
        name: the attribute name.
        field_type: the attribute's (possibly ``Annotated``) type hint.
        is_native: True for native size mode, False for standard size mode.
        validate_defaults: whether to check the attribute's default value.
        mode: struct mode character of the containing class.

    Returns:
        ``(format, field, type)``: the field's struct format with padding
        applied, the resolved ``Field``, and the underlying Python type.

    Raises:
        TypeError: if the field is not available in the chosen size mode, or
            its default value has the wrong type.
    """
    field, type_, pad_before, pad_after = _resolve_field(field_type, mode)

    if is_native:
        if not field.is_native:
            raise TypeError(
                f"field {field} only supported in standard size mode"
            )
    elif not field.is_std:
        raise TypeError(f"field {field} only supported in native size mode")

    if validate_defaults and hasattr(cls, name):
        val = getattr(cls, name)
        if isinstance(val, dataclasses.Field):
            # Declared as ``x: T = dataclasses.field(...)``: before dataclass
            # processing the class attribute is the Field object itself, not
            # the default. Validate its plain ``default`` if one was given; a
            # ``default_factory`` (required e.g. for list defaults) is not
            # called here, so there is nothing to validate for it.
            val = val.default
        if val is not dataclasses.MISSING:
            if not isinstance(field.field_type, GenericAlias) and not isinstance(
                val, field.field_type
            ):
                raise TypeError(
                    "invalid type for field: expected "
                    f"{field.field_type} got {type(val)}"
                )
            field.validate_default(val)

    return (
        _format_str_with_padding(field.format(), pad_before, pad_after),
        field,
        type_,
    )
395def _make_pack_method() -> Callable:
396 func = """
397def pack(self) -> bytes:
398 '''Pack to bytes using struct.pack.'''
399 return self.__dataclass_struct__.pack(self)
400"""
402 scope: dict[str, Any] = {}
403 exec(func, {}, scope)
404 return scope["pack"]
407def _make_unpack_method(cls: type) -> classmethod:
408 func = """
409def from_packed(cls, data: bytes) -> cls_type:
410 '''Unpack from bytes.'''
411 return cls.__dataclass_struct__.unpack(data)
412"""
414 scope: dict[str, Any] = {"cls_type": cls}
415 exec(func, {}, scope)
416 return classmethod(scope["from_packed"])
def _make_class(
    cls: type,
    mode: str,
    is_native: bool,
    validate_defaults: bool,
    dataclass_kwargs,
) -> type[DataclassStructProtocol]:
    """
    Turn ``cls`` into a dataclass_struct class and return it.

    Each type-hinted attribute is resolved into a struct field. The combined
    format (prefixed by ``mode``) is compiled into a
    ``_DataclassStructInternal`` attached as ``__dataclass_struct__``.
    ``pack``/``from_packed`` methods are then installed, and finally
    ``dataclasses.dataclass(cls, **dataclass_kwargs)`` is applied.

    ``ClassVar`` annotations are skipped: ``dataclasses`` does not treat them
    as fields, so they are neither packed nor passed to ``__init__``.
    """
    cls_annotations = get_type_hints(cls, include_extras=True)
    struct_format = [mode]
    fieldnames: list[str] = []
    fields: list[tuple[Field[Any], type]] = []
    for name, annotation in cls_annotations.items():
        # Bare ``ClassVar`` has no origin; ``ClassVar[X]`` has origin ClassVar.
        if annotation is ClassVar or get_origin(annotation) is ClassVar:
            continue

        fmt, field, type_ = _validate_and_parse_field(
            cls,
            name=name,
            field_type=annotation,
            is_native=is_native,
            validate_defaults=validate_defaults,
            mode=mode,
        )
        struct_format.append(fmt)
        fieldnames.append(name)
        fields.append((field, type_))

    setattr(  # noqa: B010
        cls,
        "__dataclass_struct__",
        _DataclassStructInternal(
            "".join(struct_format),
            cls,
            fieldnames,
            fields,
        ),
    )
    setattr(cls, "pack", _make_pack_method())  # noqa: B010
    setattr(cls, "from_packed", _make_unpack_method(cls))  # noqa: B010

    return dataclasses.dataclass(cls, **dataclass_kwargs)
457class _DataclassKwargsPre310(TypedDict, total=False):
458 init: bool
459 repr: bool
460 eq: bool
461 order: bool
462 unsafe_hash: bool
463 frozen: bool
466if sys.version_info >= (3, 10):
468 class DataclassKwargs(_DataclassKwargsPre310, total=False):
469 match_args: bool
470 kw_only: bool
471else:
473 class DataclassKwargs(_DataclassKwargsPre310, total=False):
474 pass
@overload
def dataclass_struct(
    *,
    size: Literal["native"] = "native",
    byteorder: Literal["native"] = "native",
    validate_defaults: bool = True,
    **dataclass_kwargs: Unpack[DataclassKwargs],
) -> Callable[[type], type]: ...


@overload
def dataclass_struct(
    *,
    size: Literal["std"],
    byteorder: Literal["native", "big", "little", "network"] = "native",
    validate_defaults: bool = True,
    **dataclass_kwargs: Unpack[DataclassKwargs],
) -> Callable[[type], type]: ...


@dataclass_transform()
def dataclass_struct(
    *,
    size: Literal["native", "std"] = "native",
    byteorder: Literal["native", "big", "little", "network"] = "native",
    validate_defaults: bool = True,
    **dataclass_kwargs: Unpack[DataclassKwargs],
) -> Callable[[type], type]:
    """
    Return a class decorator that turns a class into a dataclass_struct.

    Args:
        size: ``"native"`` or ``"std"`` size mode.
        byteorder: ``"native"``, ``"big"``, ``"little"`` or ``"network"``;
            anything other than ``"native"`` requires ``size="std"``.
        validate_defaults: whether default values are type-checked and
            validated against their fields.
        **dataclass_kwargs: forwarded to ``dataclasses.dataclass``; a truthy
            ``slots`` or ``weakref_slot`` is rejected.

    Raises:
        ValueError: for an invalid size/byteorder combination or an
            unsupported dataclass keyword argument.
    """
    is_native = size == "native"
    if is_native and byteorder != "native":
        raise ValueError("'native' size requires 'native' byteorder")
    if not is_native and size != "std":
        raise ValueError(f"invalid size: {size}")
    if byteorder not in ("native", "big", "little", "network"):
        raise ValueError(f"invalid byteorder: {byteorder}")

    for rejected in ("slots", "weakref_slot"):
        if not dataclass_kwargs.get(rejected):
            continue
        raise ValueError(
            f"dataclass '{rejected}' keyword argument is not supported"
        )

    # The (size, byteorder) pair has been validated above, so the lookup
    # always succeeds; resolve it once for every decorated class.
    mode_char = _SIZE_BYTEORDER_MODE_CHAR[(size, byteorder)]

    def decorator(cls: type) -> type:
        return _make_class(
            cls,
            mode=mode_char,
            is_native=is_native,
            validate_defaults=validate_defaults,
            dataclass_kwargs=dataclass_kwargs,
        )

    return decorator