% Spectral super-resolution in the presence of outliers and dense noise. 
% Recovery is carried out by a greedy approach that chooses either
% sinusoidal or spike atoms iteratively. Optionally the support of the 
% estimated spectrum is optimized at each iteration. This step, which
% involves finding a local minimum of a nonconvex function, significantly
% enhances the empirical performance of the algorithm.
% 
% For more information see the paper "Demixing Sines and Spikes: Spectral 
% Super-resolution in the Presence of Outliers" by C. Fernandez-Granda,
% G. Tang, X. Wang and L. Zheng

function [f_est, spike_est, coeffs_freq_rescaled, coeffs_spikes] ...
        = greedy_demixing(y, ind_data, thresh, max_iters, local_opt)
% GREEDY_DEMIXING  Greedily estimate sinusoidal frequencies and spikes in y.
%
%   At every iteration the sinusoidal or spike atom most correlated with
%   the residual is added. All coefficients are then refit by least
%   squares, and atoms whose coefficient magnitude is not above thresh are
%   pruned. Optionally the frequencies are refined by local optimization.
%
%   Inputs
%     y         - data vector (presumably a column; fit_coeffs solves
%                 [F I] \ y with n-row atom matrices)
%     ind_data  - sample indices of the entries of y (one per entry)
%     thresh    - pruning threshold on coefficient magnitudes
%     max_iters - maximum number of greedy iterations
%     local_opt - if true, refine the frequencies at every iteration
%
%   Outputs
%     f_est                - estimated frequencies
%     spike_est            - positions (indices into y) of estimated spikes
%     coeffs_freq_rescaled - sinusoidal amplitudes divided by sqrt(n)
%     coeffs_spikes        - spike amplitudes
%
%   The loop stops early if every atom gets pruned. In that case the
%   state is back to its initial value, so further iterations would only
%   repeat the same steps.

    n = length(y);
    coeffs_freq = [];
    coeffs_spikes = [];
    f_est = [];
    spike_est = [];

    residual = y;
    for iter = 1:max_iters

        disp(['Iteration: ' num2str(iter)])

        % Candidate atoms: the sinusoid and the spike most correlated
        % with the current residual
        [freq, freq_corr] = sinusoidal_atom(residual, ind_data);
        [spike_corr, spike_ind] = max(abs(residual));

        % Add whichever candidate correlates more strongly
        if freq_corr > spike_corr
            f_est = [f_est, freq];
            disp(['New frequency: ' num2str(freq) ' Number of frequencies: ' num2str(length(f_est))])
        else
            spike_est = unique([spike_est, spike_ind]);
            disp(['New spike: ' num2str(spike_ind) ' Number of spikes: ' num2str(length(spike_est))])
        end

        % Prune atoms whose least-squares coefficient is not above thresh.
        % One mask per atom type keeps the reported count consistent with
        % what is actually removed (including coefficients equal to thresh).
        [coeffs_freq, coeffs_spikes, ~] = fit_coeffs(y, f_est, spike_est, ...
                                                     ind_data, false);
        keep_freq = abs(coeffs_freq) > thresh;
        disp(['Pruned frequencies: ' num2str(nnz(~keep_freq))])
        f_est = f_est(keep_freq);

        keep_spikes = abs(coeffs_spikes) > thresh;
        disp(['Pruned spikes: ' num2str(nnz(~keep_spikes))])
        spike_est = spike_est(keep_spikes);

        if isempty(f_est) && isempty(spike_est)
            % Every atom was pruned: nothing is left to fit, and the next
            % iteration would start from the initial state again.
            coeffs_freq = [];
            coeffs_spikes = [];
            disp('All atoms pruned, stopping')
            break
        end

        if local_opt && ~isempty(f_est)
            % Optimize the estimated frequencies with the spike atoms
            % fixed; skipped when there are no frequencies to move
            f_opt = optimize_freq(y, f_est, spike_est, ind_data);
            disp('Local optimization')
            disp(['Old frequencies: ' num2str(f_est)])
            disp(['New frequencies: ' num2str(f_opt)])
            f_est = f_opt;
        end

        % Refit with the final atoms of this iteration and update residual
        [coeffs_freq, coeffs_spikes, fit] = fit_coeffs(y, f_est, spike_est, ind_data, true);
        residual = y - fit;

    end
    % In the measurement operator the frequency atoms are weighted by
    % sqrt(n)
    coeffs_freq_rescaled = coeffs_freq ./ sqrt(n);
end

% Function to find the sinusoidal atom most correlated with the residual:
% a search on a fine frequency grid initializes a local optimization that
% maximizes the correlation.
function [freq, freq_corr] = sinusoidal_atom(residual, ind_data)
    % Return the frequency freq in [-0.5, 0.5] whose sinusoidal atom is
    % (locally) most correlated with residual, and that correlation
    % freq_corr as computed by corr_sinusoidal.

    % Coarse search: the zero-padded inverse FFT evaluates
    % |sum_m r(m) exp(1i*2*pi*m*f)| on a grid of spacing 1/upsamp_N.
    % NOTE(review): the grid uses sample positions 1..n, not ind_data.
    % Its peak therefore matches corr_sinusoidal only when ind_data is a
    % run of consecutive integers (a constant offset only changes the
    % phase). The local refinement below does use ind_data.
    upsamp_N = 10 * length(residual);
    upsampled_fft = fftshift(ifft(residual, upsamp_N));
    % After fftshift the zero frequency sits at index floor(N/2)+1 (for
    % both even and odd N), so grid index j maps to (j - center)/N
    center = floor(upsamp_N / 2) + 1;
    [~, grid_max] = max(abs(upsampled_fft));
    f_grid = (grid_max - center) / upsamp_N;

    % Local refinement: maximize the correlation (minimize its negative)
    % within the bounds [-0.5, 0.5]
    freq = fminsearchbnd(@(x) -corr_sinusoidal(x, ind_data, residual), ...
                         f_grid, -0.5, 0.5);

    freq_corr = corr_sinusoidal(freq, ind_data, residual);

end

% Function that computes the correlation with a sinusoidal atom
function res = corr_sinusoidal(freq, ind_data, residual)
    % Magnitude of the inner product between residual and the sinusoidal
    % atom exp(-1i*2*pi*ind_data'*freq), normalized by sqrt(length(residual)).
    n_samples = length(residual);
    phase = -1i * 2 * pi * ind_data' * freq;
    sinusoid = exp(phase) / sqrt(n_samples);
    inner = sinusoid' * residual;
    res = abs(inner);
end

% Function to fit the coefficient corresponding to each atom by least
% squares
function [coeffs_freq,coeffs_spikes, fit] = fit_coeffs(y, f_est, spike_est, ...
                                           ind_data, compute_fit)
    % Least-squares fit of y onto the current sinusoidal and spike atoms.
    %
    %   y           - data, presumably a column (n = length(y))
    %   f_est       - frequencies; atom j is
    %                 exp(-1i*2*pi*ind_data'*f_est(j))/sqrt(n)
    %   spike_est   - positions (indices into y) of the spike atoms
    %   ind_data    - sample indices (transposed when building the atoms)
    %   compute_fit - if true, also return the fitted vector
    %
    %   coeffs_freq   - coefficients of the sinusoidal atoms ([] if none)
    %   coeffs_spikes - coefficients of the spike atoms ([] if none)
    %   fit           - fitted vector when compute_fit is true, [] otherwise;
    %                   zeros(size(y)) when there are no atoms at all
    n = length(y);
    k_est = length(f_est);
    s_est = length(spike_est);

    % Sinusoidal atoms (n x k_est), normalized by sqrt(n)
    F_est = zeros(n, 0);
    if k_est > 0
        F_est = exp(-1i*2*pi*ind_data'*f_est)/sqrt(n);
    end

    % Spike atoms: columns spike_est of the n x n identity, built
    % directly instead of allocating the whole identity matrix
    I_est = zeros(n, 0);
    if s_est > 0
        I_est = zeros(n, s_est);
        I_est(sub2ind([n, s_est], spike_est(:)', 1:s_est)) = 1;
    end

    coeffs_freq = [];
    coeffs_spikes = [];
    fit = [];

    if k_est + s_est == 0
        % No atoms: there is nothing to solve for and the model is zero
        if compute_fit
            fit = zeros(size(y));
        end
        return
    end

    % One joint least-squares solve covers all three atom configurations
    atoms = [F_est, I_est];
    coeffs = atoms \ y;
    if k_est > 0
        coeffs_freq = coeffs(1:k_est);
    end
    if s_est > 0
        coeffs_spikes = coeffs(k_est+1:end);
    end

    if compute_fit
        fit = atoms * coeffs;
    end

end

% Function to refine the frequency locations by minimizing the
% least-squares misfit of the data on the entries without spikes
function f_est = optimize_freq(y,f_ini,spike_est,ind_data)
    % Refine the frequencies f_ini by bound-constrained local minimization
    % of lossfun (spike entries excluded), keeping each frequency in
    % [-0.5, 0.5]. An empty f_ini is returned unchanged: there is nothing
    % to optimize, and the optimizer would otherwise get an empty start.
    k_est = length(f_ini);
    if k_est == 0
        f_est = f_ini;
        return
    end

    lb = -0.5 .* ones(1, k_est);
    ub = 0.5 .* ones(1, k_est);

    % Optimization using fminsearchbnd
    f_est = fminsearchbnd(@(x) lossfun(x, y, ind_data, spike_est), ...
                          f_ini, lb, ub);
end

function res = lossfun(x, y, ind_data, spike_est)
    % Norm of the least-squares residual of y, restricted to the entries
    % not listed in spike_est, after projecting onto the sinusoidal atoms
    % with frequencies x.
    %
    % The fit only includes entries that don't have spikes. For those
    % entries the least-squares term can always be made zero by choosing
    % the spike amplitudes adequately.
    %
    % x is forced to a row (x(:).') so the atom matrix is
    % (#non-spike entries) x (#frequencies) whether the optimizer passes
    % x as a row or a column. With a row x, x' would be a column and the
    % product would fail for two or more frequencies.
    n = length(y);
    no_spikes_ind = setdiff(1:n, spike_est);
    F_est = exp(-1i*2*pi*ind_data(no_spikes_ind)'*x(:).') / sqrt(n);
    res = norm(y(no_spikes_ind) - F_est * (F_est \ y(no_spikes_ind)), 2);

end

