Init resamplers
[rtaudio-cdist.git] / RtAudio.cpp
index ea0c202d737446748549d21c0da3282a2118b606..7fb2798f90092f7333e147f036da8141900a9105 100644 (file)
@@ -3693,7 +3693,14 @@ static const char* getAsioErrorString( ASIOError result )
 #include <avrt.h>
 #include <mmdeviceapi.h>
 #include <functiondiscoverykeys_devpkey.h>
-#include <sstream>
+
+#include <mfapi.h>
+#include <mferror.h>
+#include <mfplay.h>
+#include <Wmcodecdsp.h>
+
+#pragma comment( lib, "mfplat.lib" )
+#pragma comment( lib, "wmcodecdspuuid" )
 
 //=============================================================================
 
@@ -3867,6 +3874,187 @@ private:
 
 //-----------------------------------------------------------------------------
 
+// In order to satisfy WASAPI's buffer requirements, we need a means of converting sample rate
+// between HW and the user. The WasapiResampler class is used to perform this conversion between
+// HwIn->UserIn and UserOut->HwOut during the stream callback loop.
+class WasapiResampler
+{
+public:
+  WasapiResampler( bool isFloat, unsigned int bitsPerSample, unsigned int channelCount,
+                   unsigned int inSampleRate, unsigned int outSampleRate )
+    : _bytesPerSample( bitsPerSample / 8 )
+    , _channelCount( channelCount )
+    , _sampleRatio( ( float ) outSampleRate / inSampleRate )
+    , _transformUnk( NULL )
+    , _transform( NULL )
+    , _resamplerProps( NULL )
+    , _mediaType( NULL )
+    , _inputMediaType( NULL )
+    , _outputMediaType( NULL )
+  {
+    // 1. Initialization
+
+    MFStartup( MF_VERSION, MFSTARTUP_NOSOCKET );
+
+    // 2. Create Resampler Transform Object
+
+    CoCreateInstance( CLSID_CResamplerMediaObject, NULL, CLSCTX_INPROC_SERVER,
+                      IID_IUnknown, ( void** ) &_transformUnk );
+
+    _transformUnk->QueryInterface( IID_PPV_ARGS( &_transform ) );
+
+    _transformUnk->QueryInterface( IID_PPV_ARGS( &_resamplerProps ) );
+    _resamplerProps->SetHalfFilterLength( 60 ); // best conversion quality
+
+                                                // 3. Specify input / output format
+
+    MFCreateMediaType( &_mediaType );
+    _mediaType->SetGUID( MF_MT_MAJOR_TYPE, MFMediaType_Audio );
+    _mediaType->SetGUID( MF_MT_SUBTYPE, isFloat ? MFAudioFormat_Float : MFAudioFormat_PCM );
+    _mediaType->SetUINT32( MF_MT_AUDIO_NUM_CHANNELS, channelCount );
+    _mediaType->SetUINT32( MF_MT_AUDIO_SAMPLES_PER_SECOND, inSampleRate );
+    _mediaType->SetUINT32( MF_MT_AUDIO_BLOCK_ALIGNMENT, _bytesPerSample * channelCount );
+    _mediaType->SetUINT32( MF_MT_AUDIO_AVG_BYTES_PER_SECOND, _bytesPerSample * channelCount * inSampleRate );
+    _mediaType->SetUINT32( MF_MT_AUDIO_BITS_PER_SAMPLE, bitsPerSample );
+    _mediaType->SetUINT32( MF_MT_ALL_SAMPLES_INDEPENDENT, TRUE );
+
+    MFCreateMediaType( &_inputMediaType );
+    _mediaType->CopyAllItems( _inputMediaType );
+
+    _transform->SetInputType( 0, _inputMediaType, 0 );
+
+    MFCreateMediaType( &_outputMediaType );
+    _mediaType->CopyAllItems( _outputMediaType );
+
+    _outputMediaType->SetUINT32( MF_MT_AUDIO_SAMPLES_PER_SECOND, outSampleRate );
+    _outputMediaType->SetUINT32( MF_MT_AUDIO_AVG_BYTES_PER_SECOND, _bytesPerSample * channelCount * outSampleRate );
+
+    _transform->SetOutputType( 0, _outputMediaType, 0 );
+
+    // 4. Send stream start messages to Resampler
+
+    _transform->ProcessMessage( MFT_MESSAGE_COMMAND_FLUSH, NULL );
+    _transform->ProcessMessage( MFT_MESSAGE_NOTIFY_BEGIN_STREAMING, NULL );
+    _transform->ProcessMessage( MFT_MESSAGE_NOTIFY_START_OF_STREAM, NULL );
+  }
+
+  ~WasapiResampler()
+  {
+    // 8. Send stream stop messages to Resampler
+
+    _transform->ProcessMessage( MFT_MESSAGE_NOTIFY_END_OF_STREAM, NULL );
+    _transform->ProcessMessage( MFT_MESSAGE_NOTIFY_END_STREAMING, NULL );
+
+    // 9. Cleanup
+
+    MFShutdown();
+
+    SAFE_RELEASE( _transformUnk );
+    SAFE_RELEASE( _transform );
+    SAFE_RELEASE( _resamplerProps );
+    SAFE_RELEASE( _mediaType );
+    SAFE_RELEASE( _inputMediaType );
+    SAFE_RELEASE( _outputMediaType );
+  }
+
+  void Convert( char* outBuffer, const char* inBuffer, unsigned int inSampleCount, unsigned int& outSampleCount )
+  {
+    unsigned int inputBufferSize = _bytesPerSample * _channelCount * inSampleCount;
+    if ( _sampleRatio == 1 )
+    {
+      // no sample rate conversion required
+      memcpy( outBuffer, inBuffer, inputBufferSize );
+      outSampleCount = inSampleCount;
+      return;
+    }
+
+    unsigned int outputBufferSize = ( unsigned int ) ceilf( inputBufferSize * _sampleRatio ) + ( _bytesPerSample * _channelCount );
+
+    IMFMediaBuffer* rInBuffer;
+    IMFSample* rInSample;
+    BYTE* rInByteBuffer = NULL;
+
+    // 5. Create Sample object from input data
+
+    MFCreateMemoryBuffer( inputBufferSize, &rInBuffer );
+
+    rInBuffer->Lock( &rInByteBuffer, NULL, NULL );
+    memcpy( rInByteBuffer, inBuffer, inputBufferSize );
+    rInBuffer->Unlock();
+    rInByteBuffer = NULL;
+
+    rInBuffer->SetCurrentLength( inputBufferSize );
+
+    MFCreateSample( &rInSample );
+    rInSample->AddBuffer( rInBuffer );
+
+    // 6. Pass input data to Resampler
+
+    _transform->ProcessInput( 0, rInSample, 0 );
+
+    SAFE_RELEASE( rInBuffer );
+    SAFE_RELEASE( rInSample );
+
+    // 7. Perform sample rate conversion
+
+    IMFMediaBuffer* rOutBuffer = NULL;
+    BYTE* rOutByteBuffer = NULL;
+
+    MFT_OUTPUT_DATA_BUFFER rOutDataBuffer;
+    DWORD rStatus;
+    DWORD rBytes = outputBufferSize; // maximum bytes accepted per ProcessOutput
+
+                                     // 7.1 Create Sample object for output data
+
+    memset( &rOutDataBuffer, 0, sizeof rOutDataBuffer );
+    MFCreateSample( &( rOutDataBuffer.pSample ) );
+    MFCreateMemoryBuffer( rBytes, &rOutBuffer );
+    rOutDataBuffer.pSample->AddBuffer( rOutBuffer );
+    rOutDataBuffer.dwStreamID = 0;
+    rOutDataBuffer.dwStatus = 0;
+    rOutDataBuffer.pEvents = NULL;
+
+    // 7.2 Get output data from Resampler
+
+    if ( _transform->ProcessOutput( 0, 1, &rOutDataBuffer, &rStatus ) == MF_E_TRANSFORM_NEED_MORE_INPUT )
+    {
+      outSampleCount = 0;
+      SAFE_RELEASE( rOutBuffer );
+      SAFE_RELEASE( rOutDataBuffer.pSample );
+      return;
+    }
+
+    // 7.3 Write output data to outBuffer
+
+    SAFE_RELEASE( rOutBuffer );
+    rOutDataBuffer.pSample->ConvertToContiguousBuffer( &rOutBuffer );
+    rOutBuffer->GetCurrentLength( &rBytes );
+
+    rOutBuffer->Lock( &rOutByteBuffer, NULL, NULL );
+    memcpy( outBuffer, rOutByteBuffer, rBytes );
+    rOutBuffer->Unlock();
+    rOutByteBuffer = NULL;
+
+    outSampleCount = rBytes / _bytesPerSample / _channelCount;
+    SAFE_RELEASE( rOutBuffer );
+    SAFE_RELEASE( rOutDataBuffer.pSample );
+  }
+
+private:
+  unsigned int _bytesPerSample;
+  unsigned int _channelCount;
+  float _sampleRatio;
+
+  IUnknown* _transformUnk;
+  IMFTransform* _transform;
+  IWMResamplerProps* _resamplerProps;
+  IMFMediaType* _mediaType;
+  IMFMediaType* _inputMediaType;
+  IMFMediaType* _outputMediaType;
+};
+
+//-----------------------------------------------------------------------------
+
 // A structure to hold various information related to the WASAPI implementation.
 struct WasapiHandle
 {
@@ -4132,11 +4320,14 @@ RtAudio::DeviceInfo RtApiWasapi::getDeviceInfo( unsigned int device )
     info.duplexChannels = 0;
   }
 
-  // sample rates (WASAPI only supports the one native sample rate)
-  info.preferredSampleRate = deviceFormat->nSamplesPerSec;
-
+  // sample rates
   info.sampleRates.clear();
-  info.sampleRates.push_back( deviceFormat->nSamplesPerSec );
+
+  // allow support for all sample rates as we have a built-in sample rate converter
+  for ( unsigned int i = 0; i < MAX_SAMPLE_RATES; i++ ) {
+    info.sampleRates.push_back( SAMPLE_RATES[i] );
+  }
+  info.preferredSampleRate = deviceFormat->nSamplesPerSec;
 
   // native format
   info.nativeFormats = 0;
@@ -4413,7 +4604,6 @@ bool RtApiWasapi::probeDeviceOpen( unsigned int device, StreamMode mode, unsigne
   WAVEFORMATEX* deviceFormat = NULL;
   unsigned int bufferBytes;
   stream_.state = STREAM_STOPPED;
-  RtAudio::DeviceInfo deviceInfo;
 
   // create API Handle if not already created
   if ( !stream_.apiHandle )
@@ -4454,20 +4644,6 @@ bool RtApiWasapi::probeDeviceOpen( unsigned int device, StreamMode mode, unsigne
     goto Exit;
   }
 
-  deviceInfo = getDeviceInfo( device );
-
-  // validate sample rate
-  if ( sampleRate != deviceInfo.preferredSampleRate )
-  {
-    errorType = RtAudioError::INVALID_USE;
-    std::stringstream ss;
-    ss << "RtApiWasapi::probeDeviceOpen: " << sampleRate
-       << "Hz sample rate not supported. This device only supports "
-       << deviceInfo.preferredSampleRate << "Hz.";
-    errorText_ = ss.str();
-    goto Exit;
-  }
-
   // determine whether index falls within capture or render devices
   if ( device >= renderDeviceCount ) {
     if ( mode != INPUT ) {
@@ -4551,7 +4727,7 @@ bool RtApiWasapi::probeDeviceOpen( unsigned int device, StreamMode mode, unsigne
   stream_.nUserChannels[mode] = channels;
   stream_.channelOffset[mode] = firstChannel;
   stream_.userFormat = format;
-  stream_.deviceFormat[mode] = deviceInfo.nativeFormats;
+  stream_.deviceFormat[mode] = getDeviceInfo( device ).nativeFormats;
 
   if ( options && options->flags & RTAUDIO_NONINTERLEAVED )
     stream_.userInterleaved = false;
@@ -4651,8 +4827,12 @@ void RtApiWasapi::wasapiThread()
 
   WAVEFORMATEX* captureFormat = NULL;
   WAVEFORMATEX* renderFormat = NULL;
+  float captureSrRatio = 0.0f;
+  float renderSrRatio = 0.0f;
   WasapiBuffer captureBuffer;
   WasapiBuffer renderBuffer;
+  WasapiResampler* captureResampler = NULL;
+  WasapiResampler* renderResampler = NULL;
 
   // declare local stream variables
   RtAudioCallback callback = ( RtAudioCallback ) stream_.callbackInfo.callback;
@@ -4660,11 +4840,15 @@ void RtApiWasapi::wasapiThread()
   unsigned long captureFlags = 0;
   unsigned int bufferFrameCount = 0;
   unsigned int numFramesPadding = 0;
-  bool callbackPushed = false;
+  unsigned int convBufferSize = 0;
+  bool callbackPushed = true;
   bool callbackPulled = false;
   bool callbackStopped = false;
   int callbackResult = 0;
 
+  // convBuffer is used to store converted buffers between WASAPI and the user
+  char* convBuffer = NULL;
+  unsigned int convBuffSize = 0;
   unsigned int deviceBuffSize = 0;
 
   errorText_.clear();
@@ -4687,8 +4871,16 @@ void RtApiWasapi::wasapiThread()
       goto Exit;
     }
 
+    // init captureResampler
+    captureResampler = new WasapiResampler( stream_.deviceFormat[INPUT] == RTAUDIO_FLOAT32 || stream_.deviceFormat[INPUT] == RTAUDIO_FLOAT64,
+                                            formatBytes( stream_.deviceFormat[INPUT] ) * 8, stream_.nDeviceChannels[INPUT],
+                                            captureFormat->nSamplesPerSec, stream_.sampleRate );
+
+    captureSrRatio = ( ( float ) captureFormat->nSamplesPerSec / stream_.sampleRate );
+
     // initialize capture stream according to desire buffer size
-    REFERENCE_TIME desiredBufferPeriod = ( REFERENCE_TIME ) ( ( float ) stream_.bufferSize * 10000000 / captureFormat->nSamplesPerSec );
+    float desiredBufferSize = stream_.bufferSize * captureSrRatio;
+    REFERENCE_TIME desiredBufferPeriod = ( REFERENCE_TIME ) ( ( float ) desiredBufferSize * 10000000 / captureFormat->nSamplesPerSec );
 
     if ( !captureClient ) {
       hr = captureAudioClient->Initialize( AUDCLNT_SHAREMODE_SHARED,
@@ -4735,7 +4927,7 @@ void RtApiWasapi::wasapiThread()
     }
 
     // scale outBufferSize according to stream->user sample rate ratio
-    unsigned int outBufferSize = ( unsigned int ) stream_.bufferSize * stream_.nDeviceChannels[INPUT];
+    unsigned int outBufferSize = ( unsigned int ) ceilf( stream_.bufferSize * captureSrRatio ) * stream_.nDeviceChannels[INPUT];
     inBufferSize *= stream_.nDeviceChannels[INPUT];
 
     // set captureBuffer size
@@ -4764,8 +4956,16 @@ void RtApiWasapi::wasapiThread()
       goto Exit;
     }
 
+    // init renderResampler
+    renderResampler = new WasapiResampler( stream_.deviceFormat[OUTPUT] == RTAUDIO_FLOAT32 || stream_.deviceFormat[OUTPUT] == RTAUDIO_FLOAT64,
+                                           formatBytes( stream_.deviceFormat[OUTPUT] ) * 8, stream_.nDeviceChannels[OUTPUT],
+                                           stream_.sampleRate, renderFormat->nSamplesPerSec );
+
+    renderSrRatio = ( ( float ) renderFormat->nSamplesPerSec / stream_.sampleRate );
+
     // initialize render stream according to desire buffer size
-    REFERENCE_TIME desiredBufferPeriod = ( REFERENCE_TIME ) ( ( float ) stream_.bufferSize * 10000000 / renderFormat->nSamplesPerSec );
+    float desiredBufferSize = stream_.bufferSize * renderSrRatio;
+    REFERENCE_TIME desiredBufferPeriod = ( REFERENCE_TIME ) ( ( float ) desiredBufferSize * 10000000 / renderFormat->nSamplesPerSec );
 
     if ( !renderClient ) {
       hr = renderAudioClient->Initialize( AUDCLNT_SHAREMODE_SHARED,
@@ -4812,7 +5012,7 @@ void RtApiWasapi::wasapiThread()
     }
 
     // scale inBufferSize according to user->stream sample rate ratio
-    unsigned int inBufferSize = ( unsigned int ) stream_.bufferSize * stream_.nDeviceChannels[OUTPUT];
+    unsigned int inBufferSize = ( unsigned int ) ceilf( stream_.bufferSize * renderSrRatio ) * stream_.nDeviceChannels[OUTPUT];
     outBufferSize *= stream_.nDeviceChannels[OUTPUT];
 
     // set renderBuffer size
@@ -4833,20 +5033,30 @@ void RtApiWasapi::wasapiThread()
     }
   }
 
-  if ( stream_.mode == INPUT ) {
-    using namespace std; // for roundf
+  // malloc buffer memory
+  if ( stream_.mode == INPUT )
+  {
+    using namespace std; // for ceilf
+    convBuffSize = ( size_t ) ( ceilf( stream_.bufferSize * captureSrRatio ) ) * stream_.nDeviceChannels[INPUT] * formatBytes( stream_.deviceFormat[INPUT] );
     deviceBuffSize = stream_.bufferSize * stream_.nDeviceChannels[INPUT] * formatBytes( stream_.deviceFormat[INPUT] );
   }
-  else if ( stream_.mode == OUTPUT ) {
+  else if ( stream_.mode == OUTPUT )
+  {
+    convBuffSize = ( size_t ) ( ceilf( stream_.bufferSize * renderSrRatio ) ) * stream_.nDeviceChannels[OUTPUT] * formatBytes( stream_.deviceFormat[OUTPUT] );
     deviceBuffSize = stream_.bufferSize * stream_.nDeviceChannels[OUTPUT] * formatBytes( stream_.deviceFormat[OUTPUT] );
   }
-  else if ( stream_.mode == DUPLEX ) {
+  else if ( stream_.mode == DUPLEX )
+  {
+    convBuffSize = std::max( ( size_t ) ( ceilf( stream_.bufferSize * captureSrRatio ) ) * stream_.nDeviceChannels[INPUT] * formatBytes( stream_.deviceFormat[INPUT] ),
+                             ( size_t ) ( ceilf( stream_.bufferSize * renderSrRatio ) ) * stream_.nDeviceChannels[OUTPUT] * formatBytes( stream_.deviceFormat[OUTPUT] ) );
     deviceBuffSize = std::max( stream_.bufferSize * stream_.nDeviceChannels[INPUT] * formatBytes( stream_.deviceFormat[INPUT] ),
                                stream_.bufferSize * stream_.nDeviceChannels[OUTPUT] * formatBytes( stream_.deviceFormat[OUTPUT] ) );
   }
 
+  convBuffSize *= 2; // allow overflow for *SrRatio remainders
+  convBuffer = ( char* ) malloc( convBuffSize );
   stream_.deviceBuffer = ( char* ) malloc( deviceBuffSize );
-  if ( !stream_.deviceBuffer ) {
+  if ( !convBuffer || !stream_.deviceBuffer ) {
     errorType = RtAudioError::MEMORY_ERROR;
     errorText_ = "RtApiWasapi::wasapiThread: Error allocating device buffer memory.";
     goto Exit;
@@ -4858,15 +5068,26 @@ void RtApiWasapi::wasapiThread()
       // Callback Input
       // ==============
       // 1. Pull callback buffer from inputBuffer
-      // 2. If 1. was successful: Convert callback buffer to user format
+      // 2. If 1. was successful: Convert callback buffer to user sample rate and channel count
+      //                          Convert callback buffer to user format
 
       if ( captureAudioClient ) {
         // Pull callback buffer from inputBuffer
-        callbackPulled = captureBuffer.pullBuffer( stream_.deviceBuffer,
-                                                   ( unsigned int ) stream_.bufferSize * stream_.nDeviceChannels[INPUT],
+        callbackPulled = captureBuffer.pullBuffer( convBuffer,
+                                                   ( unsigned int ) ( stream_.bufferSize * captureSrRatio ) * stream_.nDeviceChannels[INPUT],
                                                    stream_.deviceFormat[INPUT] );
 
         if ( callbackPulled ) {
+          // Convert callback buffer to user sample rate
+          convertBufferWasapi( stream_.deviceBuffer,
+                               convBuffer,
+                               stream_.nDeviceChannels[INPUT],
+                               captureFormat->nSamplesPerSec,
+                               stream_.sampleRate,
+                               ( unsigned int ) ( stream_.bufferSize * captureSrRatio ),
+                               convBufferSize,
+                               stream_.deviceFormat[INPUT] );
+
           if ( stream_.doConvertBuffer[INPUT] ) {
             // Convert callback buffer to user format
             convertBuffer( stream_.userBuffer[INPUT],
@@ -4940,7 +5161,8 @@ void RtApiWasapi::wasapiThread()
     // Callback Output
     // ===============
     // 1. Convert callback buffer to stream format
-    // 2. Push callback buffer into outputBuffer
+    // 2. Convert callback buffer to stream sample rate and channel count
+    // 3. Push callback buffer into outputBuffer
 
     if ( renderAudioClient && callbackPulled ) {
       if ( stream_.doConvertBuffer[OUTPUT] ) {
@@ -4951,9 +5173,19 @@ void RtApiWasapi::wasapiThread()
 
       }
 
+      // Convert callback buffer to stream sample rate
+      convertBufferWasapi( convBuffer,
+                           stream_.deviceBuffer,
+                           stream_.nDeviceChannels[OUTPUT],
+                           stream_.sampleRate,
+                           renderFormat->nSamplesPerSec,
+                           stream_.bufferSize,
+                           convBufferSize,
+                           stream_.deviceFormat[OUTPUT] );
+
       // Push callback buffer into outputBuffer
-      callbackPushed = renderBuffer.pushBuffer( stream_.deviceBuffer,
-                                                stream_.bufferSize * stream_.nDeviceChannels[OUTPUT],
+      callbackPushed = renderBuffer.pushBuffer( convBuffer,
+                                                convBufferSize * stream_.nDeviceChannels[OUTPUT],
                                                 stream_.deviceFormat[OUTPUT] );
     }
     else {
@@ -5099,6 +5331,8 @@ Exit:
   CoTaskMemFree( captureFormat );
   CoTaskMemFree( renderFormat );
 
+  free ( convBuffer );
+
   CoUninitialize();
 
   // update stream state