diff --git a/libavcodec/wmaprodec.c b/libavcodec/wmaprodec.c
index d06e3b73efc3b5a4b725d28b7739899855416683..105e27999edfb4b8e46227e1278fba83c224bc2b 100644
--- a/libavcodec/wmaprodec.c
+++ b/libavcodec/wmaprodec.c
@@ -322,12 +322,12 @@ static av_cold int decode_init(AVCodecContext *avctx)
     for (i = 0; i < avctx->extradata_size; i++)
         av_log(avctx, AV_LOG_DEBUG, "[%x] ", avctx->extradata[i]);
     av_log(avctx, AV_LOG_DEBUG, "\n");
-    if (avctx->codec_id == AV_CODEC_ID_XMA2 && avctx->extradata_size >= 34) {
+    if (avctx->codec_id == AV_CODEC_ID_XMA2 && (!avctx->extradata || avctx->extradata_size >= 6)) {
         s->decode_flags    = 0x10d6;
-        channel_mask       = AV_RL32(edata_ptr+2);
+        channel_mask       = avctx->extradata ? AV_RL32(edata_ptr+2) : 0;
         s->bits_per_sample = 16;
 
-     } else if (avctx->codec_id == AV_CODEC_ID_XMA1 && avctx->extradata_size >= 28) {
+     } else if (avctx->codec_id == AV_CODEC_ID_XMA1) {
         s->decode_flags    = 0x10d6;
         s->bits_per_sample = 16;
         channel_mask       = 0;