Commit c486f571 authored by Jonathan Hadida's avatar Jonathan Hadida
Browse files

Improve helptext + check inputs + add options to return best K particles

parent e6ef8151
function best = particle_filter( objective_fun, range, variance, Nparticles, Nkeep, Nnew, Ncycles ) function [best_particles,best_scores] = particle_filter( objective_fun, range, variance, Nparticles, Ncycles, varargin )
%
% [best_particles,best_scores] = PARTICLE_FILTER( objective_fun, range, variance, Nparticles, Ncycles )
%
% ----------
% INPUTS:
% %
% objective_fun % objective_fun
% Function handle which, given a particle, returns a positive score to be maximised. % Function handle which, given a particle, returns a positive score to be maximised.
...@@ -6,44 +11,86 @@ function best = particle_filter( objective_fun, range, variance, Nparticles, Nke ...@@ -6,44 +11,86 @@ function best = particle_filter( objective_fun, range, variance, Nparticles, Nke
% range % range
% A Ndim x 2 array in which each row specifies the min/max value in each dimension. % A Ndim x 2 array in which each row specifies the min/max value in each dimension.
% %
% variance
% A Ndim x 1 array with variances in each dimension used during resampling.
%
% Nparticles % Nparticles
% Some large integer. % Some large integer (>50).
%
% Ncycles
% Number of cycles to run (>10).
%
% ----------
% OPTIONS (key/value pairs):
% %
% Nkeep % Nkeep
% Number of particles serving as seeds during resampling. % Number of particles serving as seeds during resampling.
% DEFAULT: Nparticles/2
% %
% Nnew % Nnew
% Number of new particles to generate at each cycle. % Number of new particles to generate at each cycle.
% DEFAULT: Nparticles/10
% %
% Ncycles % BestK
% Number of cycles to run. % Number of best particles to return in output.
% DEFAULT: 10
%
% ----------
% OUTPUTS:
%
% best_particles
% The particle(s) that obtained the best score(s) throughout all cycles.
%
% best_scores
% The corresponding scores.
%
assert( Nkeep < Nparticles ); % parse options
assert( Nnew < Nparticles ); opt = struct();
arg = { 'Nkeep', ceil(Nparticles/2), 'Nnew', ceil(Nparticles/10), 'BestK', 10 }; % defaults
arg = [ arg, varargin ];
narg = numel(arg)/2;
for i = 1:narg
opt.(arg{ 2*i-1 }) = arg{2*i};
end
% check inputs
assert( Nparticles > 50, 'More particles please.' );
assert( Ncycles > 10, 'More cycles please.' );
assert( all(variance >= 0), 'Variances should be positive.' );
assert( all(range(:,2) > range(:,1)), 'Second column should be larger than first column in range.' );
assert( opt.Nkeep < Nparticles, 'Nkeep cant be larger than Nparticles.' );
assert( opt.Nnew < Nparticles, 'Nnew cant be larger than Nparticles.' );
assert( opt.BestK < Nparticles, 'BestK cant be larger than Nparticles.' );
% initialise the particles % initialise the particles
particles = initialise( range, Nparticles ); particles = initialise( range, Nparticles );
scores = evaluate( objective_fun, particles ); scores = evaluate( objective_fun, particles );
% initialise output % initialise output
[best_score,best] = max(scores); [~,order] = sort(scores,'descend');
best = particles(:,best); best_scores = scores(order(1:opt.BestK));
best_particles = particles(:,order(1:opt.BestK));
for c = 1:Ncycles for c = 1:Ncycles
% resample % resample
keep = importance_sampling( scores, Nkeep ); keep = importance_sampling( scores, opt.Nkeep );
particles = [ resample( particles(:,keep), variance, Nparticles-Nnew ), initialise(range,Nnew) ]; particles = [ resample( particles(:,keep), variance, Nparticles-opt.Nnew ), initialise( range, opt.Nnew ) ];
% evaluate % evaluate
scores = evaluate( objective_fun, particles ); scores = evaluate( objective_fun, particles );
% update best particle % update best particles
[~,best_index] = max(scores); [~,order] = sort([scores,best_scores],'descend');
if scores(best_index) > best_score order = order(1:opt.BestK);
best_score = scores(best_index); id_keep = order( order > Nparticles ) - Nparticles;
best = particles(:,best_index); id_new = order( order <= Nparticles );
end
best_scores = [ best_scores(id_keep), scores(id_new) ];
best_particles = [ best_particles(:,id_keep), particles(:,id_new) ];
end end
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment