Commit c486f571 authored by Jonathan Hadida's avatar Jonathan Hadida
Browse files

Improve helptext + check inputs + add options to return best K particles

parent e6ef8151
function best = particle_filter( objective_fun, range, variance, Nparticles, Nkeep, Nnew, Ncycles )
function [best_particles,best_scores] = particle_filter( objective_fun, range, variance, Nparticles, Ncycles, varargin )
%
% [best_particles,best_scores] = PARTICLE_FILTER( objective_fun, range, variance, Nparticles, Ncycles )
%
% ----------
% INPUTS:
%
% objective_fun
% Function handle which, given a particle, returns a positive score to be maximised.
......@@ -6,44 +11,86 @@ function best = particle_filter( objective_fun, range, variance, Nparticles, Nke
% range
% A Ndim x 2 array in which each row specifies the min/max value in each dimension.
%
% variance
% A Ndim x 1 array with variances in each dimension used during resampling.
%
% Nparticles
% Some large integer.
% Some large integer (>50).
%
% Ncycles
% Number of cycles to run (>10).
%
% ----------
% OPTIONS (key/value pairs):
%
% Nkeep
% Number of particles serving as seeds during resampling.
% DEFAULT: Nparticles/2
%
% Nnew
% Number of new particles to generate at each cycle.
% DEFAULT: Nparticles/10
%
% Ncycles
% Number of cycles to run.
% BestK
% Number of best particles to return in output.
% DEFAULT: 10
%
% ----------
% OUTPUTS:
%
% best_particles
% The particle(s) that obtained the best score(s) throughout all cycles.
%
% best_scores
% The corresponding scores.
%
% parse options
opt = struct();
arg = { 'Nkeep', ceil(Nparticles/2), 'Nnew', ceil(Nparticles/10), 'BestK', 10 }; % defaults
arg = [ arg, varargin ];
narg = numel(arg)/2;
for i = 1:narg
opt.(arg{ 2*i-1 }) = arg{2*i};
end
% check inputs
assert( Nparticles > 50, 'More particles please.' );
assert( Ncycles > 10, 'More cycles please.' );
assert( all(variance >= 0), 'Variances should be positive.' );
assert( all(range(:,2) > range(:,1)), 'Second column should be larger than first column in range.' );
assert( Nkeep < Nparticles );
assert( Nnew < Nparticles );
assert( opt.Nkeep < Nparticles, 'Nkeep cant be larger than Nparticles.' );
assert( opt.Nnew < Nparticles, 'Nnew cant be larger than Nparticles.' );
assert( opt.BestK < Nparticles, 'BestK cant be larger than Nparticles.' );
% initialise the particles
particles = initialise( range, Nparticles );
scores = evaluate( objective_fun, particles );
% initialise output
[best_score,best] = max(scores);
best = particles(:,best);
[~,order] = sort(scores,'descend');
best_scores = scores(order(1:opt.BestK));
best_particles = particles(:,order(1:opt.BestK));
for c = 1:Ncycles
% resample
keep = importance_sampling( scores, Nkeep );
particles = [ resample( particles(:,keep), variance, Nparticles-Nnew ), initialise(range,Nnew) ];
keep = importance_sampling( scores, opt.Nkeep );
particles = [ resample( particles(:,keep), variance, Nparticles-opt.Nnew ), initialise( range, opt.Nnew ) ];
% evaluate
scores = evaluate( objective_fun, particles );
% update best particle
[~,best_index] = max(scores);
if scores(best_index) > best_score
best_score = scores(best_index);
best = particles(:,best_index);
end
% update best particles
[~,order] = sort([scores,best_scores],'descend');
order = order(1:opt.BestK);
id_keep = order( order > Nparticles ) - Nparticles;
id_new = order( order <= Nparticles );
best_scores = [ best_scores(id_keep), scores(id_new) ];
best_particles = [ best_particles(:,id_keep), particles(:,id_new) ];
end
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment