diff --git a/libavresample/audio_mix.c b/libavresample/audio_mix.c
index a3f11b657f7d51d02f943d97f2d18abd8330d2b4..fc37eac6a382f4f5a4896a921fec6db1868f90c8 100644
--- a/libavresample/audio_mix.c
+++ b/libavresample/audio_mix.c
@@ -527,28 +527,13 @@ int ff_audio_mix_get_matrix(AudioMix *am, double *matrix, int stride)
     return 0;
 }
 
-int ff_audio_mix_set_matrix(AudioMix *am, const double *matrix, int stride)
+static void reduce_matrix(AudioMix *am, const double *matrix, int stride)
 {
-    int i, o, i0, o0, ret;
-    char in_layout_name[128];
-    char out_layout_name[128];
-
-    if ( am->in_channels <= 0 ||  am->in_channels > AVRESAMPLE_MAX_CHANNELS ||
-        am->out_channels <= 0 || am->out_channels > AVRESAMPLE_MAX_CHANNELS) {
-        av_log(am->avr, AV_LOG_ERROR, "Invalid channel counts\n");
-        return AVERROR(EINVAL);
-    }
-
-    if (am->matrix) {
-        av_free(am->matrix[0]);
-        am->matrix = NULL;
-    }
+    int i, o;
 
     memset(am->output_zero, 0, sizeof(am->output_zero));
     memset(am->input_skip,  0, sizeof(am->input_skip));
-    memset(am->output_skip, 0, sizeof(am->output_zero));
-    am->in_matrix_channels  = am->in_channels;
-    am->out_matrix_channels = am->out_channels;
+    memset(am->output_skip, 0, sizeof(am->output_skip));
 
     /* exclude output channels if they can be zeroed instead of mixed */
     for (o = 0; o < am->out_channels; o++) {
@@ -578,7 +563,7 @@ int ff_audio_mix_set_matrix(AudioMix *am, const double *matrix, int stride)
     }
     if (am->out_matrix_channels == 0) {
         am->in_matrix_channels = 0;
-        return 0;
+        return;
     }
 
     /* skip input channels that contribute fully only to the corresponding
@@ -615,7 +600,7 @@ int ff_audio_mix_set_matrix(AudioMix *am, const double *matrix, int stride)
     }
     if (am->in_matrix_channels == 0) {
         am->out_matrix_channels = 0;
-        return 0;
+        return;
     }
 
     /* skip output channels that only get full contribution from the
@@ -637,8 +622,31 @@ int ff_audio_mix_set_matrix(AudioMix *am, const double *matrix, int stride)
     }
     if (am->out_matrix_channels == 0) {
         am->in_matrix_channels = 0;
-        return 0;
+        return;
     }
+}
+
+int ff_audio_mix_set_matrix(AudioMix *am, const double *matrix, int stride)
+{
+    int i, o, i0, o0, ret;
+    char in_layout_name[128];
+    char out_layout_name[128];
+
+    if ( am->in_channels <= 0 ||  am->in_channels > AVRESAMPLE_MAX_CHANNELS ||
+        am->out_channels <= 0 || am->out_channels > AVRESAMPLE_MAX_CHANNELS) {
+        av_log(am->avr, AV_LOG_ERROR, "Invalid channel counts\n");
+        return AVERROR(EINVAL);
+    }
+
+    if (am->matrix) {
+        av_free(am->matrix[0]);
+        am->matrix = NULL;
+    }
+
+    am->in_matrix_channels  = am->in_channels;
+    am->out_matrix_channels = am->out_channels;
+
+    reduce_matrix(am, matrix, stride);
 
 #define CONVERT_MATRIX(type, expr)                                          \
     am->matrix_## type[0] = av_mallocz(am->out_matrix_channels *            \
@@ -664,6 +672,7 @@ int ff_audio_mix_set_matrix(AudioMix *am, const double *matrix, int stride)
     }                                                                       \
     am->matrix = (void **)am->matrix_## type;
 
+    if (am->in_matrix_channels && am->out_matrix_channels) {
     switch (am->coeff_type) {
     case AV_MIX_COEFF_TYPE_Q8:
         CONVERT_MATRIX(q8, av_clip_int16(lrint(256.0 * v)))
@@ -678,6 +687,7 @@ int ff_audio_mix_set_matrix(AudioMix *am, const double *matrix, int stride)
         av_log(am->avr, AV_LOG_ERROR, "Invalid mix coeff type\n");
         return AVERROR(EINVAL);
     }
+    }
 
     ret = mix_function_init(am);
     if (ret < 0)