#ifndef DLIB_MKL_FFT_H
#define DLIB_MKL_FFT_H
#include <type_traits>
#include <mkl_dfti.h>
#include "fft_size.h"
// Throws dlib::error carrying the DFTI error message when the status `s` is
// non-zero and is not classified as DFTI_NO_ERROR by DftiErrorClass().
// The status expression is evaluated exactly once, and the do/while wrapper
// makes the macro expand to a single statement so it can be used safely as
// the body of an unbraced if/else. Must be followed by a semicolon.
#define DLIB_DFTI_CHECK_STATUS(s) \
    do \
    { \
        const MKL_LONG dlib_dfti_status_ = (s); \
        if (dlib_dfti_status_ != 0 && !DftiErrorClass(dlib_dfti_status_, DFTI_NO_ERROR)) \
        { \
            throw dlib::error(DftiErrorMessage(dlib_dfti_status_)); \
        } \
    } while (false)
namespace dlib
{
template<typename T>
void mkl_fft(const fft_size& dims, const std::complex<T>* in, std::complex<T>* out, bool is_inverse)
/*!
requires
- T must be either float or double
- dims represents the dimensions of both `in` and `out`
- dims.num_dims() > 0
- dims.num_dims() < 3
ensures
- performs an FFT on `in` and stores the result in `out`.
- if `is_inverse` is true, a backward FFT is performed,
otherwise a forward FFT is performed.
!*/
{
static_assert(std::is_floating_point<T>::value, "template parameter needs to be a floatint point type");
DLIB_ASSERT(dims.num_dims() > 0, "dims can't be empty");
DLIB_ASSERT(dims.num_dims() < 3, "we currently only support up to 2D FFT. Please submit an issue on github if 3D or above is required.");
constexpr DFTI_CONFIG_VALUE dfti_type = std::is_same<T,float>::value ? DFTI_SINGLE : DFTI_DOUBLE;
DFTI_DESCRIPTOR_HANDLE h;
MKL_LONG status;
if (dims.num_dims() == 1)
{
status = DftiCreateDescriptor(&h, dfti_type, DFTI_COMPLEX, 1, dims[0]);
DLIB_DFTI_CHECK_STATUS(status);
}
else
{
MKL_LONG size[] = {dims[0], dims[1]};
status = DftiCreateDescriptor(&h, dfti_type, DFTI_COMPLEX, 2, size);
DLIB_DFTI_CHECK_STATUS(status);
MKL_LONG strides[3];
strides[0] = 0;
strides[1] = size[1];
strides[2] = 1;
status = DftiSetValue(h, DFTI_INPUT_STRIDES, strides);
DLIB_DFTI_CHECK_STATUS(status);
status = DftiSetValue(h, DFTI_OUTPUT_STRIDES, strides);
DLIB_DFTI_CHECK_STATUS(status);
}
const DFTI_CONFIG_VALUE inplacefft = in == out ? DFTI_INPLACE : DFTI_NOT_INPLACE;
status = DftiSetValue(h, DFTI_PLACEMENT, inplacefft);
DLIB_DFTI_CHECK_STATUS(status);
// Unless we use sequential mode, the fft results are not correct.
status = DftiSetValue(h, DFTI_THREAD_LIMIT, 1);
DLIB_DFTI_CHECK_STATUS(status);
status = DftiCommitDescriptor(h);
DLIB_DFTI_CHECK_STATUS(status);
if (is_inverse)
status = DftiComputeBackward(h, (void*)in, (void*)out);
else
status = DftiComputeForward(h, (void*)in, (void*)out);
DLIB_DFTI_CHECK_STATUS(status);
status = DftiFreeDescriptor(&h);
DLIB_DFTI_CHECK_STATUS(status);
}
/*
* in has dims[0] * dims[1] * ... * dims[-2] * dims[-1] points
* out has dims[0] * dims[1] * ... * dims[-2] * (dims[-1]/2+1) points
*/
template<typename T>
void mkl_fftr(const fft_size& dims, const T* in, std::complex<T>* out)
/*!
requires
- T must be either float or double
- dims represent the dimensions of `in`
- `out` has dimensions {dims[0], dims[1], ..., dims[-2], dims[-1]/2+1}
- dims.num_dims() > 0
- dims.num_dims() <= 3
- dims.back() must be even
ensures
- performs a real FFT on `in` and stores the result in `out`.
!*/
{
static_assert(std::is_floating_point<T>::value, "template parameter needs to be a floatint point type");
DLIB_ASSERT(dims.num_dims() > 0, "dims can't be empty");
DLIB_ASSERT(dims.num_dims() < 3, "we currently only support up to 2D FFT. Please submit an issue on github if 3D or above is required.");
DLIB_ASSERT(dims.back() % 2 == 0, "last dimension needs to be even");
constexpr DFTI_CONFIG_VALUE dfti_type = std::is_same<T,float>::value ? DFTI_SINGLE : DFTI_DOUBLE;
DFTI_DESCRIPTOR_HANDLE h;
MKL_LONG status;
if (dims.num_dims() == 1)
{
status = DftiCreateDescriptor(&h, dfti_type, DFTI_REAL, 1, dims[0]);
DLIB_DFTI_CHECK_STATUS(status);
}
else
{
const long lastdim = dims[1]/2+1;
MKL_LONG size[] = {dims[0], dims[1]};
status = DftiCreateDescriptor(&h, dfti_type, DFTI_REAL, 2, size);
DLIB_DFTI_CHECK_STATUS(status);
{
MKL_LONG strides[3];
strides[0] = 0;
strides[1] = size[1];
strides[2] = 1;
status = DftiSetValue(h, DFTI_INPUT_STRIDES, strides);
DLIB_DFTI_CHECK_STATUS(status);
}
{
MKL_LONG strides[3];
strides[0] = 0;
strides[1] = lastdim;
strides[2] = 1;
status = DftiSetValue(h, DFTI_OUTPUT_STRIDES, strides);
DLIB_DFTI_CHECK_STATUS(status);
}
}
const DFTI_CONFIG_VALUE inplacefft = (void*)in == (void*)out ? DFTI_INPLACE : DFTI_NOT_INPLACE;
status = DftiSetValue(h, DFTI_PLACEMENT, inplacefft);
DLIB_DFTI_CHECK_STATUS(status);
status = DftiSetValue(h, DFTI_CONJUGATE_EVEN_STORAGE, DFTI_COMPLEX_COMPLEX);
DLIB_DFTI_CHECK_STATUS(status);
// Unless we use sequential mode, the fft results are not correct.
status = DftiSetValue(h, DFTI_THREAD_LIMIT, 1);
DLIB_DFTI_CHECK_STATUS(status);
status = DftiCommitDescriptor(h);
DLIB_DFTI_CHECK_STATUS(status);
status = DftiComputeForward(h, (void*)in, (void*)out);
DLIB_DFTI_CHECK_STATUS(status);
status = DftiFreeDescriptor(&h);
DLIB_DFTI_CHECK_STATUS(status);
}
/*
* in has dims[0] * dims[1] * ... * dims[-2] * (dims[-1]/2+1) points
* out has dims[0] * dims[1] * ... * dims[-2] * dims[-1] points
*/
template<typename T>
void mkl_ifftr(const fft_size& dims, const std::complex<T>* in, T* out)
/*!
requires
- T must be either float or double
- dims represent the dimensions of `out`
- `in` has dimensions {dims[0], dims[1], ..., dims[-2], dims[-1]/2+1}
- dims.num_dims() > 0
- dims.num_dims() <= 3
- dims.back() must be even
ensures
- performs an inverse real FFT on `in` and stores the result in `out`.
!*/
{
static_assert(std::is_floating_point<T>::value, "template parameter needs to be a floatint point type");
DLIB_ASSERT(dims.num_dims() > 0, "dims can't be empty");
DLIB_ASSERT(dims.num_dims() < 3, "we currently only support up to 2D FFT. Please submit an issue on github if 3D or above is required.");
DLIB_ASSERT(dims.back() % 2 == 0, "last dimension needs to be even");
constexpr DFTI_CONFIG_VALUE dfti_type = std::is_same<T,float>::value ? DFTI_SINGLE : DFTI_DOUBLE;
DFTI_DESCRIPTOR_HANDLE h;
MKL_LONG status;
if (dims.num_dims() == 1)
{
status = DftiCreateDescriptor(&h, dfti_type, DFTI_REAL, 1, dims[0]);
DLIB_DFTI_CHECK_STATUS(status);
}
else
{
const long lastdim = dims[1]/2+1;
MKL_LONG size[] = {dims[0], dims[1]};
status = DftiCreateDescriptor(&h, dfti_type, DFTI_REAL, 2, size);
DLIB_DFTI_CHECK_STATUS(status);
{
MKL_LONG strides[3];
strides[0] = 0;
strides[1] = lastdim;
strides[2] = 1;
status = DftiSetValue(h, DFTI_INPUT_STRIDES, strides);
DLIB_DFTI_CHECK_STATUS(status);
}
{
MKL_LONG strides[3];
strides[0] = 0;
strides[1] = dims[1];
strides[2] = 1;
status = DftiSetValue(h, DFTI_OUTPUT_STRIDES, strides);
DLIB_DFTI_CHECK_STATUS(status);
}
}
const DFTI_CONFIG_VALUE inplacefft = (void*)in == (void*)out ? DFTI_INPLACE : DFTI_NOT_INPLACE;
status = DftiSetValue(h, DFTI_PLACEMENT, inplacefft);
DLIB_DFTI_CHECK_STATUS(status);
status = DftiSetValue(h, DFTI_CONJUGATE_EVEN_STORAGE, DFTI_COMPLEX_COMPLEX);
DLIB_DFTI_CHECK_STATUS(status);
// Unless we use sequential mode, the fft results are not correct.
status = DftiSetValue(h, DFTI_THREAD_LIMIT, 1);
DLIB_DFTI_CHECK_STATUS(status);
status = DftiCommitDescriptor(h);
DLIB_DFTI_CHECK_STATUS(status);
status = DftiComputeBackward(h, (void*)in, (void*)out);
DLIB_DFTI_CHECK_STATUS(status);
status = DftiFreeDescriptor(&h);
DLIB_DFTI_CHECK_STATUS(status);
}
}
#endif // DLIB_MKL_FFT_H