diff --git a/libavcodec/av1_parse.c b/libavcodec/av1_parse.c index 50dd940f03..cdd524baa8 100644 --- a/libavcodec/av1_parse.c +++ b/libavcodec/av1_parse.c @@ -22,6 +22,7 @@ #include "libavutil/mem.h" +#include "av1.h" #include "av1_parse.h" #include "bytestream.h" @@ -29,16 +30,13 @@ int ff_av1_extract_obu(AV1OBU *obu, const uint8_t *buf, int length, void *logctx { int64_t obu_size; int start_pos, type, temporal_id, spatial_id; - int len, ret; + int len; len = parse_obu_header(buf, length, &obu_size, &start_pos, &type, &temporal_id, &spatial_id); if (len < 0) return len; - if (obu_size > INT_MAX / 8 || obu_size < 0) - return AVERROR(ERANGE); - obu->type = type; obu->temporal_id = temporal_id; obu->spatial_id = spatial_id; @@ -48,10 +46,6 @@ int ff_av1_extract_obu(AV1OBU *obu, const uint8_t *buf, int length, void *logctx obu->raw_data = buf; obu->raw_size = len; - ret = init_get_bits(&obu->gb, obu->data, obu->size * 8); - if (ret < 0) - return ret; - av_log(logctx, AV_LOG_DEBUG, "obu_type: %d, temporal_id: %d, spatial_id: %d, payload size: %d\n", obu->type, obu->temporal_id, obu->spatial_id, obu->size); @@ -62,7 +56,7 @@ int ff_av1_extract_obu(AV1OBU *obu, const uint8_t *buf, int length, void *logctx int ff_av1_packet_split(AV1Packet *pkt, const uint8_t *buf, int length, void *logctx) { GetByteContext bc; - int consumed; + int ret, consumed; bytestream2_init(&bc, buf, length); pkt->nb_obus = 0; @@ -87,9 +81,20 @@ int ff_av1_packet_split(AV1Packet *pkt, const uint8_t *buf, int length, void *lo if (consumed < 0) return consumed; + bytestream2_skip(&bc, consumed); + + obu->size_bits = get_obu_bit_length(obu->data, obu->size, obu->type); + + if (obu->size_bits < 0 || (!obu->size_bits && obu->type != AV1_OBU_TEMPORAL_DELIMITER)) { + av_log(logctx, AV_LOG_ERROR, "Invalid OBU of type %d, skipping.\n", obu->type); + continue; + } + pkt->nb_obus++; - bytestream2_skip(&bc, consumed); + ret = init_get_bits(&obu->gb, obu->data, obu->size_bits); + if (ret < 0) + return ret; } return 0; diff --git a/libavcodec/av1_parse.h b/libavcodec/av1_parse.h index 9a6e6835ab..0de619dbec 100644 --- a/libavcodec/av1_parse.h +++ b/libavcodec/av1_parse.h @@ -23,6 +23,7 @@ #include +#include "av1.h" #include "avcodec.h" #include "get_bits.h" @@ -31,6 +32,12 @@ typedef struct AV1OBU { int size; const uint8_t *data; + /** + * Size, in bits, of just the data, excluding the trailing_one_bit and + * any trailing padding. + */ + int size_bits; + /** Size of entire OBU, including header */ int raw_size; const uint8_t *raw_data; @@ -133,4 +140,35 @@ static inline int parse_obu_header(const uint8_t *buf, int buf_size, return size; } +static inline int get_obu_bit_length(const uint8_t *buf, int size, int type) +{ + int v; + + /* There are no trailing bits on these */ + if (type == AV1_OBU_TILE_GROUP || type == AV1_OBU_FRAME) { + if (size > INT_MAX / 8) + return AVERROR(ERANGE); + else + return size * 8; + } + + while (size > 0 && buf[size - 1] == 0) + size--; + + if (!size) + return 0; + + v = buf[size - 1]; + + if (size > INT_MAX / 8) + return AVERROR(ERANGE); + size *= 8; + + /* Remove the trailing_one_bit and following trailing zeros */ + if (v) + size -= ff_ctz(v) + 1; + + return size; +} + #endif /* AVCODEC_AV1_PARSE_H */