Coverage for dataclasses_struct/dataclass.py: 98%

229 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-06-15 10:33 +1200

1import dataclasses 

2import sys 

3from collections.abc import Generator, Iterator 

4from struct import Struct 

5from types import GenericAlias 

6from typing import ( 

7 Annotated, 

8 Any, 

9 Callable, 

10 Generic, 

11 Literal, 

12 Protocol, 

13 TypedDict, 

14 TypeVar, 

15 Union, 

16 get_args, 

17 get_origin, 

18 get_type_hints, 

19 overload, 

20) 

21 

22from ._typing import TypeGuard, Unpack, dataclass_transform 

23from .field import Field, builtin_fields 

24from .types import PadAfter, PadBefore 

25 

26 

27def _separate_padding_from_annotation_args(args) -> tuple[int, int, object]: 

28 pad_before = pad_after = 0 

29 extra_arg = None # should be Field or integer for bytes/list types 

30 for arg in args: 

31 if isinstance(arg, PadBefore): 

32 pad_before += arg.size 

33 elif isinstance(arg, PadAfter): 

34 pad_after += arg.size 

35 elif extra_arg is not None: 

36 raise TypeError(f"too many annotations: {arg}") 

37 else: 

38 extra_arg = arg 

39 

40 return pad_before, pad_after, extra_arg 

41 

42 

43def _format_str_with_padding(fmt: str, pad_before: int, pad_after: int) -> str: 

44 return "".join( 

45 ( 

46 (f"{pad_before}x" if pad_before else ""), 

47 fmt, 

48 (f"{pad_after}x" if pad_after else ""), 

49 ) 

50 ) 

51 

52 

53T = TypeVar("T") 

54 

55 

56_SIZE_BYTEORDER_MODE_CHAR: dict[tuple[str, str], str] = { 

57 ("native", "native"): "@", 

58 ("std", "native"): "=", 

59 ("std", "little"): "<", 

60 ("std", "big"): ">", 

61 ("std", "network"): "!", 

62} 

63_MODE_CHAR_SIZE_BYTEORDER: dict[str, tuple[str, str]] = { 

64 v: k for k, v in _SIZE_BYTEORDER_MODE_CHAR.items() 

65} 

66 

67 

class _DataclassStructInternal(Generic[T]):
    """
    Packing/unpacking machinery attached to a dataclass-struct class.

    An instance is stored on each decorated class as ``__dataclass_struct__``.
    It holds the compiled ``struct.Struct`` for the class's flattened layout,
    the field names, and the resolved ``(Field, type)`` pair for each field,
    both in annotation order.
    """

    struct: Struct
    cls: type[T]
    _fieldnames: list[str]
    _fields: list[tuple[Field[Any], type]]

    @property
    def format(self) -> str:
        """Struct format string, including the leading mode character."""
        return self.struct.format

    @property
    def size(self) -> int:
        """Size of the packed representation in bytes."""
        return self.struct.size

    @property
    def mode(self) -> str:
        """Leading struct mode character (the first character of format)."""
        return self.format[0]

    def __init__(
        self,
        fmt: str,
        cls: type,
        fieldnames: list[str],
        fields: list[tuple[Field[Any], type]],
    ):
        self.struct = Struct(fmt)
        self.cls = cls
        self._fieldnames = fieldnames
        self._fields = fields
        # Lazily computed by _kw_only_fieldnames(): `cls` is only turned into
        # a dataclass after this object has been created.
        self._kw_only_names = None

    def _flattened_attrs(self, outer_self: T) -> list[Any]:
        """
        Returns a list of all attributes of `outer_self`, including those of
        any nested structs.
        """
        attrs: list[Any] = []
        for fieldname in self._fieldnames:
            attr = getattr(outer_self, fieldname)
            self._flatten_attr(attrs, attr)
        return attrs

    @staticmethod
    def _flatten_attr(attrs: list[Any], attr: object) -> None:
        """Append `attr` to `attrs`, recursing into nested structs and lists."""
        if is_dataclass_struct(attr):
            attrs.extend(attr.__dataclass_struct__._flattened_attrs(attr))
        elif isinstance(attr, list):
            for sub_attr in attr:
                _DataclassStructInternal._flatten_attr(attrs, sub_attr)
        else:
            attrs.append(attr)

    def pack(self, obj: T) -> bytes:
        """Pack `obj` into bytes using the compiled struct."""
        return self.struct.pack(*self._flattened_attrs(obj))

    def _arg_generator(self, args: Iterator) -> Generator:
        """Yield one constructor value per field, consuming `args`."""
        for field, fieldtype in self._fields:
            yield from _DataclassStructInternal._generate_args_recursively(
                args, field, fieldtype
            )

    @staticmethod
    def _generate_args_recursively(
        args: Iterator,
        field: Field[Any],
        field_type: type,
    ) -> Generator:
        """
        Yield the value of a single field, consuming unpacked values from
        `args`: a nested struct instance, a list of `field.n` items for a
        fixed-length array, or `field_type(next(args))` otherwise.
        """
        if is_dataclass_struct(field_type):
            yield field_type.__dataclass_struct__._init_from_args(args)
        elif isinstance(field, _FixedLengthArrayField):
            items: list = []
            for _ in range(field.n):
                items.extend(
                    _DataclassStructInternal._generate_args_recursively(
                        args, field.item_field, field.item_type
                    )
                )
            yield items
        else:
            yield field_type(next(args))

    def _kw_only_fieldnames(self) -> frozenset:
        """
        Return the names of fields that `self.cls.__init__` only accepts by
        keyword. The result is computed on first use and cached.

        `dataclasses.Field.kw_only` exists on Python 3.10+; on older versions
        the attribute is absent and no field is treated as keyword-only.
        """
        if self._kw_only_names is None:
            self._kw_only_names = frozenset(
                f.name
                for f in dataclasses.fields(self.cls)
                if getattr(f, "kw_only", False)
            )
        return self._kw_only_names

    def _init_from_args(self, args: Iterator) -> T:
        """
        Returns an instance of self.cls, consuming args.

        Fields are passed positionally in declaration order, except for
        keyword-only fields (e.g. the class was decorated with
        `kw_only=True`), which are passed by name. Passing those
        positionally would make `__init__` raise TypeError.
        """
        kw_only = self._kw_only_fieldnames()
        positional: list[Any] = []
        keywords: dict[str, Any] = {}
        for name, value in zip(self._fieldnames, self._arg_generator(args)):
            if name in kw_only:
                keywords[name] = value
            else:
                positional.append(value)
        return self.cls(*positional, **keywords)

    def unpack(self, data: bytes) -> T:
        """Unpack `data` with the compiled struct into a new `self.cls`."""
        return self._init_from_args(iter(self.struct.unpack(data)))

156 

157 

class DataclassStructProtocol(Protocol):
    """
    Structural type of a class decorated with ``dataclass_struct``.

    Such a class carries a ``__dataclass_struct__`` attribute and is given
    a ``pack`` method and a ``from_packed`` classmethod.
    """

    __dataclass_struct__: _DataclassStructInternal

    @classmethod
    def from_packed(cls: type[T], data: bytes) -> T:
        """Build an instance from its packed byte representation."""
        ...

    def pack(self) -> bytes:
        """Return the packed byte representation of this instance."""
        ...

165 

166 

@overload
def is_dataclass_struct(
    obj: type,
) -> TypeGuard[type[DataclassStructProtocol]]: ...


@overload
def is_dataclass_struct(obj: object) -> TypeGuard[DataclassStructProtocol]: ...


def is_dataclass_struct(
    obj: Union[type, object],
) -> Union[
    TypeGuard[DataclassStructProtocol],
    TypeGuard[type[DataclassStructProtocol]],
]:
    """
    Report whether ``obj`` is a dataclasses_struct.dataclass class or an
    instance of one.

    ``obj`` qualifies when it is a dataclass and its ``__dataclass_struct__``
    attribute is a ``_DataclassStructInternal``.
    """
    if not dataclasses.is_dataclass(obj):
        return False
    if not hasattr(obj, "__dataclass_struct__"):
        return False
    return isinstance(obj.__dataclass_struct__, _DataclassStructInternal)

192 

193 

def get_struct_size(cls_or_obj) -> int:
    """
    Return the number of bytes in the packed form of a dataclass-struct.

    ``cls_or_obj`` may be either a decorated class or an instance of one.

    Raises:
        TypeError: if ``cls_or_obj`` is not a dataclass-struct.
    """
    if is_dataclass_struct(cls_or_obj):
        return cls_or_obj.__dataclass_struct__.size
    raise TypeError(f"{cls_or_obj} is not a dataclass_struct")

202 

203 

class _BytesField(Field[bytes]):
    """Fixed-size ``bytes`` field, declared as ``Annotated[bytes, n]``."""

    field_type = bytes

    def __init__(self, n: object):
        # `n` comes straight from the Annotated metadata, so validate it here.
        length_ok = isinstance(n, int) and n >= 1
        if not length_ok:
            raise ValueError("bytes length must be positive non-zero int")
        self.n = n

    def format(self) -> str:
        # struct's "<count>s" code describes a fixed-size byte string.
        return f"{self.n}s"

    def validate_default(self, val: bytes) -> None:
        too_long = len(val) > self.n
        if too_long:
            raise ValueError(f"bytes cannot be longer than {self.n} bytes")

    def __repr__(self) -> str:
        base = super().__repr__()
        return f"{base}({self.n})"

222 

223 

class _NestedField(Field):
    """Field whose value is itself a dataclass-struct instance."""

    field_type: type[DataclassStructProtocol]

    def __init__(self, cls: type[DataclassStructProtocol]):
        self.field_type = cls

    def format(self) -> str:
        nested_format = self.field_type.__dataclass_struct__.format
        # Drop the nested class's leading byteorder/size character; the
        # container's format already starts with its own.
        return nested_format[1:]

233 

234 

class _FixedLengthArrayField(Field[list]):
    """
    Field for ``Annotated[list[<item type>], n]``, holding exactly ``n`` items.

    The item annotation is resolved in the container's mode, and any
    padding attached to it is repeated around every item.
    """

    field_type = list

    def __init__(self, item_type_annotation: Any, mode: str, n: object):
        if not (isinstance(n, int) and n >= 1):
            raise ValueError(
                "fixed-length array length must be positive non-zero int"
            )

        resolved = _resolve_field(item_type_annotation, mode)
        self.item_field, self.item_type, self.pad_before, self.pad_after = (
            resolved
        )
        self.n = n
        # The array is supported in a size mode exactly when its items are.
        self.is_native = self.item_field.is_native
        self.is_std = self.item_field.is_std

    def format(self) -> str:
        item_fmt = self.item_field.format()
        padded_item = _format_str_with_padding(
            item_fmt, self.pad_before, self.pad_after
        )
        return padded_item * self.n

    def __repr__(self) -> str:
        base = super().__repr__()
        return f"{base}({self.item_field!r}, {self.n})"

261 

262 

263def _validate_modes_match(mode: str, nested_mode: str) -> None: 

264 if mode != nested_mode: 

265 size, byteorder = _MODE_CHAR_SIZE_BYTEORDER[nested_mode] 

266 exp_size, exp_byteorder = _MODE_CHAR_SIZE_BYTEORDER[mode] 

267 msg = ( 

268 "byteorder and size of nested dataclass-struct does not " 

269 f"match that of container (expected '{exp_size}' size and " 

270 f"'{exp_byteorder}' byteorder, got '{size}' size and " 

271 f"'{byteorder}' byteorder)" 

272 ) 

273 raise TypeError(msg) 

274 

275 

def _resolve_field(
    annotation: Any,
    mode: str,
) -> tuple[Field[Any], type, int, int]:
    """
    Returns 4-tuple of:
    * field
    * type
    * number of padding bytes before
    * number of padding bytes after

    Valid type annotations are:

    1. <bool | int | float | bytes> | Annotated[<bool | int | float | bytes>, <padding>]

       Supported builtin types.

    2. Annotated[<bool | int | float | bytes>, Field(...), <padding>]

       (These are the types defined in dataclasses_struct.types e.g. U32, F32).

    3. <dataclasses_struct class> | Annotated[<dataclasses_struct class>, <padding>]

       Must have the same size and byteorder as the container.

    4. Annotated[bytes, <n>, <padding>]

       Where <n> is >0.

    5. Annotated[list[<type>], <n>, <padding>]

       Where <n> is >0 and <type> is one of the above.

    <padding> is an optional mixture of PadBefore and PadAfter annotations,
    which may be repeated. E.g.

        Annotated[int, PadBefore(5), PadAfter(2), PadBefore(3)]

    Raises:
        TypeError: if the annotation is not one of the forms above, carries
            more than one non-padding argument, or names a nested struct
            whose mode differs from `mode`.
        ValueError: if a bytes/list length is not a positive int.
    """  # noqa: E501

    if get_origin(annotation) is Annotated:
        type_, *args = get_args(annotation)
        pad_before, pad_after, annotation_arg = (
            _separate_padding_from_annotation_args(args)
        )
    else:
        pad_before = pad_after = 0
        type_ = annotation
        annotation_arg = None

    field: Field[Any]
    if annotation_arg is None:
        if get_origin(type_) is list:
            msg = (
                "list types must be marked as a fixed-length using "
                "Annotated, ex: Annotated[list[int], 5]"
            )
            raise TypeError(msg)

        # Must be either a nested type or one of the supported builtins
        if is_dataclass_struct(type_):
            _validate_modes_match(mode, type_.__dataclass_struct__.mode)
            field = _NestedField(type_)
        else:
            opt_field = builtin_fields.get(type_)
            if opt_field is None:
                raise TypeError(f"type not supported: {annotation}")
            field = opt_field
    elif isinstance(annotation_arg, Field):
        field = annotation_arg
    elif get_origin(type_) is list:
        item_annotations = get_args(type_)
        # Explicit check instead of `assert` (which is stripped under -O):
        # bare `typing.List` has no item type and `list[a, b]` has two.
        if len(item_annotations) != 1:
            raise TypeError(
                "fixed-length list annotations must have exactly one item "
                f"type, ex: Annotated[list[int], 5]; got {annotation!r}"
            )
        field = _FixedLengthArrayField(
            item_annotations[0], mode, annotation_arg
        )
    elif isinstance(type_, type) and issubclass(type_, bytes):
        # The isinstance guard keeps non-class annotations (e.g. Unions)
        # from hitting issubclass's own "must be a class" TypeError.
        field = _BytesField(annotation_arg)
    else:
        raise TypeError(f"invalid field annotation: {annotation!r}")

    return field, type_, pad_before, pad_after

357 

358 

def _validate_and_parse_field(
    cls: type,
    name: str,
    field_type: type,
    is_native: bool,
    validate_defaults: bool,
    mode: str,
) -> tuple[str, Field, type]:
    """
    Resolve one class annotation into its struct format fragment.

    Checks that the field is usable in the requested size mode. When
    ``validate_defaults`` is set and ``cls`` has an attribute ``name``,
    that default must be an instance of the field's type (skipped when the
    field type is a generic alias) and pass the field's ``validate_default``.

    Returns:
        (format fragment including padding, field, resolved python type)
    """
    field, type_, pad_before, pad_after = _resolve_field(field_type, mode)

    if is_native and not field.is_native:
        raise TypeError(
            f"field {field} only supported in standard size mode"
        )
    if not is_native and not field.is_std:
        raise TypeError(f"field {field} only supported in native size mode")

    if validate_defaults and hasattr(cls, name):
        default = getattr(cls, name)
        expected = field.field_type
        type_checkable = not isinstance(expected, GenericAlias)
        if type_checkable and not isinstance(default, expected):
            raise TypeError(
                "invalid type for field: expected "
                f"{expected} got {type(default)}"
            )
        field.validate_default(default)

    fragment = _format_str_with_padding(field.format(), pad_before, pad_after)
    return fragment, field, type_

393 

394 

395def _make_pack_method() -> Callable: 

396 func = """ 

397def pack(self) -> bytes: 

398 '''Pack to bytes using struct.pack.''' 

399 return self.__dataclass_struct__.pack(self) 

400""" 

401 

402 scope: dict[str, Any] = {} 

403 exec(func, {}, scope) 

404 return scope["pack"] 

405 

406 

407def _make_unpack_method(cls: type) -> classmethod: 

408 func = """ 

409def from_packed(cls, data: bytes) -> cls_type: 

410 '''Unpack from bytes.''' 

411 return cls.__dataclass_struct__.unpack(data) 

412""" 

413 

414 scope: dict[str, Any] = {"cls_type": cls} 

415 exec(func, {}, scope) 

416 return classmethod(scope["from_packed"]) 

417 

418 

def _make_class(
    cls: type,
    mode: str,
    is_native: bool,
    validate_defaults: bool,
    dataclass_kwargs,
) -> type[DataclassStructProtocol]:
    """
    Turn ``cls`` into a dataclass-struct and return the resulting dataclass.

    Each type hint of ``cls`` (including ``Annotated`` extras) is validated
    and converted into a struct format fragment; the fragments are joined
    after the ``mode`` prefix. A ``_DataclassStructInternal`` plus generated
    ``pack``/``from_packed`` methods are set on ``cls`` before
    ``dataclasses.dataclass`` is applied with ``dataclass_kwargs``.
    """
    hints = get_type_hints(cls, include_extras=True)

    fragments = [mode]
    resolved_fields = []
    for attr_name, annotation in hints.items():
        fragment, resolved, python_type = _validate_and_parse_field(
            cls,
            name=attr_name,
            field_type=annotation,
            is_native=is_native,
            validate_defaults=validate_defaults,
            mode=mode,
        )
        fragments.append(fragment)
        resolved_fields.append((resolved, python_type))

    internal = _DataclassStructInternal(
        "".join(fragments),
        cls,
        list(hints),
        resolved_fields,
    )
    setattr(cls, "__dataclass_struct__", internal)  # noqa: B010
    setattr(cls, "pack", _make_pack_method())  # noqa: B010
    setattr(cls, "from_packed", _make_unpack_method(cls))  # noqa: B010

    return dataclasses.dataclass(cls, **dataclass_kwargs)

455 

456 

457class _DataclassKwargsPre310(TypedDict, total=False): 

458 init: bool 

459 repr: bool 

460 eq: bool 

461 order: bool 

462 unsafe_hash: bool 

463 frozen: bool 

464 

465 

466if sys.version_info >= (3, 10): 

467 

468 class DataclassKwargs(_DataclassKwargsPre310, total=False): 

469 match_args: bool 

470 kw_only: bool 

471else: 

472 

473 class DataclassKwargs(_DataclassKwargsPre310, total=False): 

474 pass 

475 

476 

@overload
def dataclass_struct(
    *,
    size: Literal["native"] = "native",
    byteorder: Literal["native"] = "native",
    validate_defaults: bool = True,
    **dataclass_kwargs: Unpack[DataclassKwargs],
) -> Callable[[type], type]: ...


@overload
def dataclass_struct(
    *,
    size: Literal["std"],
    byteorder: Literal["native", "big", "little", "network"] = "native",
    validate_defaults: bool = True,
    **dataclass_kwargs: Unpack[DataclassKwargs],
) -> Callable[[type], type]: ...


@dataclass_transform()
def dataclass_struct(
    *,
    size: Literal["native", "std"] = "native",
    byteorder: Literal["native", "big", "little", "network"] = "native",
    validate_defaults: bool = True,
    **dataclass_kwargs: Unpack[DataclassKwargs],
) -> Callable[[type], type]:
    """
    Return a class decorator that turns a class into a dataclass-struct.

    Args:
        size: ``"native"`` or ``"std"`` struct sizing.
        byteorder: ``"native"``, ``"big"``, ``"little"`` or ``"network"``;
            ``"native"`` size only allows ``"native"`` byteorder.
        validate_defaults: check class-level default values against their
            field types when the class is decorated.
        **dataclass_kwargs: forwarded to ``dataclasses.dataclass``; truthy
            ``slots`` or ``weakref_slot`` are rejected.

    Raises:
        ValueError: for an invalid size/byteorder combination or an
            unsupported dataclass keyword argument.
    """
    is_native = size == "native"
    if is_native and byteorder != "native":
        raise ValueError("'native' size requires 'native' byteorder")
    if not is_native and size != "std":
        raise ValueError(f"invalid size: {size}")
    if byteorder not in ("native", "big", "little", "network"):
        raise ValueError(f"invalid byteorder: {byteorder}")

    for unsupported in ("slots", "weakref_slot"):
        if dataclass_kwargs.get(unsupported):
            raise ValueError(
                f"dataclass '{unsupported}' keyword argument is not supported"
            )

    def wrap(cls: type) -> type:
        return _make_class(
            cls,
            mode=_SIZE_BYTEORDER_MODE_CHAR[(size, byteorder)],
            is_native=is_native,
            validate_defaults=validate_defaults,
            dataclass_kwargs=dataclass_kwargs,
        )

    return wrap