Moved nesting decrement logic to class NestingLimit

This commit is contained in:
Benoit Blanchon
2020-02-13 16:54:18 +01:00
parent 6e52f242b2
commit fbffadb2cf
4 changed files with 88 additions and 84 deletions

View File

@@ -20,15 +20,16 @@ class MsgPackDeserializer {
public:
MsgPackDeserializer(MemoryPool &pool, TReader reader,
TStringStorage stringStorage, uint8_t nestingLimit)
: _pool(&pool),
_reader(reader),
_stringStorage(stringStorage),
_nestingLimit(nestingLimit) {}
TStringStorage stringStorage)
: _pool(&pool), _reader(reader), _stringStorage(stringStorage) {}
// TODO: add support for filter
DeserializationError parse(VariantData &variant,
AllowAllFilter = AllowAllFilter()) {
DeserializationError parse(VariantData &variant, AllowAllFilter,
NestingLimit nestingLimit) {
return parse(variant, nestingLimit);
}
DeserializationError parse(VariantData &variant, NestingLimit nestingLimit) {
uint8_t code;
if (!readByte(code)) return DeserializationError::IncompleteInput;
@@ -48,11 +49,11 @@ class MsgPackDeserializer {
}
if ((code & 0xf0) == 0x90) {
return readArray(variant.toArray(), code & 0x0F);
return readArray(variant.toArray(), code & 0x0F, nestingLimit);
}
if ((code & 0xf0) == 0x80) {
return readObject(variant.toObject(), code & 0x0F);
return readObject(variant.toObject(), code & 0x0F, nestingLimit);
}
switch (code) {
@@ -116,16 +117,16 @@ class MsgPackDeserializer {
return readString<uint32_t>(variant);
case 0xdc:
return readArray<uint16_t>(variant.toArray());
return readArray<uint16_t>(variant.toArray(), nestingLimit);
case 0xdd:
return readArray<uint32_t>(variant.toArray());
return readArray<uint32_t>(variant.toArray(), nestingLimit);
case 0xde:
return readObject<uint16_t>(variant.toObject());
return readObject<uint16_t>(variant.toObject(), nestingLimit);
case 0xdf:
return readObject<uint32_t>(variant.toObject());
return readObject<uint32_t>(variant.toObject(), nestingLimit);
default:
return DeserializationError::NotSupported;
@@ -242,36 +243,40 @@ class MsgPackDeserializer {
}
template <typename TSize>
DeserializationError readArray(CollectionData &array) {
DeserializationError readArray(CollectionData &array,
NestingLimit nestingLimit) {
TSize size;
if (!readInteger(size)) return DeserializationError::IncompleteInput;
return readArray(array, size);
return readArray(array, size, nestingLimit);
}
DeserializationError readArray(CollectionData &array, size_t n) {
if (_nestingLimit == 0) return DeserializationError::TooDeep;
--_nestingLimit;
DeserializationError readArray(CollectionData &array, size_t n,
NestingLimit nestingLimit) {
if (nestingLimit.reached()) return DeserializationError::TooDeep;
for (; n; --n) {
VariantData *value = array.add(_pool);
if (!value) return DeserializationError::NoMemory;
DeserializationError err = parse(*value);
DeserializationError err = parse(*value, nestingLimit.decrement());
if (err) return err;
}
++_nestingLimit;
return DeserializationError::Ok;
}
template <typename TSize>
DeserializationError readObject(CollectionData &object) {
DeserializationError readObject(CollectionData &object,
NestingLimit nestingLimit) {
TSize size;
if (!readInteger(size)) return DeserializationError::IncompleteInput;
return readObject(object, size);
return readObject(object, size, nestingLimit);
}
DeserializationError readObject(CollectionData &object, size_t n) {
if (_nestingLimit == 0) return DeserializationError::TooDeep;
--_nestingLimit;
DeserializationError readObject(CollectionData &object, size_t n,
NestingLimit nestingLimit) {
if (nestingLimit.reached()) return DeserializationError::TooDeep;
for (; n; --n) {
VariantSlot *slot = object.addSlot(_pool);
if (!slot) return DeserializationError::NoMemory;
@@ -281,10 +286,10 @@ class MsgPackDeserializer {
if (err) return err;
slot->setOwnedKey(make_not_null(key));
err = parse(*slot->data());
err = parse(*slot->data(), nestingLimit.decrement());
if (err) return err;
}
++_nestingLimit;
return DeserializationError::Ok;
}
@@ -312,7 +317,6 @@ class MsgPackDeserializer {
MemoryPool *_pool;
TReader _reader;
TStringStorage _stringStorage;
uint8_t _nestingLimit;
};
template <typename TInput>