Commit f5c844d3 authored by Jonathan Hadida's avatar Jonathan Hadida
Browse files

Improve resampling and score-keeping

parent c486f571
......@@ -15,25 +15,21 @@ function [best_particles,best_scores] = particle_filter( objective_fun, range, v
% A Ndim x 1 array with variances in each dimension used during resampling.
%
% Nparticles
% Some large integer (>50).
% Some large integer (>30).
%
% Ncycles
% Number of cycles to run (>10).
% Number of cycles to run (>20).
%
% ----------
% OPTIONS (key/value pairs):
%
% Nkeep
% Number of particles serving as seeds during resampling.
% DEFAULT: Nparticles/2
%
% Nnew
% Number of new particles to generate at each cycle.
% DEFAULT: Nparticles/10
% DEFAULT: Nparticles/20
%
% BestK
% Number of best particles to return in output.
% DEFAULT: 10
% DEFAULT: 5
%
% ----------
% OUTPUTS:
......@@ -42,12 +38,12 @@ function [best_particles,best_scores] = particle_filter( objective_fun, range, v
% The particle(s) that obtained the best score(s) throughout all cycles.
%
% best_scores
% The corresponding scores.
% The record of all best scores at each cycle (matrix Ncycles x BestK, last row most recent).
%
% parse options
opt = struct();
arg = { 'Nkeep', ceil(Nparticles/2), 'Nnew', ceil(Nparticles/10), 'BestK', 10 }; % defaults
arg = { 'Nnew', ceil(Nparticles/20), 'BestK', 5 }; % defaults
arg = [ arg, varargin ];
narg = numel(arg)/2;
......@@ -56,47 +52,50 @@ function [best_particles,best_scores] = particle_filter( objective_fun, range, v
end
% check inputs
assert( Nparticles > 50, 'More particles please.' );
assert( Ncycles > 10, 'More cycles please.' );
assert( Nparticles > 30, 'More particles please.' );
assert( Ncycles > 20, 'More cycles please.' );
assert( all(variance >= 0), 'Variances should be positive.' );
assert( all(range(:,2) > range(:,1)), 'Second column should be larger than first column in range.' );
assert( opt.Nkeep < Nparticles, 'Nkeep cant be larger than Nparticles.' );
assert( opt.Nnew < Nparticles, 'Nnew cant be larger than Nparticles.' );
assert( opt.BestK < Nparticles, 'BestK cant be larger than Nparticles.' );
assert( opt.BestK > 0, 'BestK should be non-zero.' );
% initialise the particles
particles = initialise( range, Nparticles );
particles = uniform_sampling( range, Nparticles );
scores = evaluate( objective_fun, particles );
% initialise output
[~,order] = sort(scores,'descend');
best_scores = scores(order(1:opt.BestK));
best_particles = particles(:,order(1:opt.BestK));
[~,order] = sort(scores,'descend');
order = order(1:opt.BestK);
best_scores = ones(Ncycles+1,1) * scores(order);
best_particles = particles(:,order);
for c = 1:Ncycles
% resample
keep = importance_sampling( scores, opt.Nkeep );
particles = [ resample( particles(:,keep), variance, Nparticles-opt.Nnew ), initialise( range, opt.Nnew ) ];
particles = horzcat( ...
importance_sampling( particles, scores, variance, Nparticles-opt.Nnew ), ...
uniform_sampling( range, opt.Nnew ) );
% evaluate
scores = evaluate( objective_fun, particles );
% update best particles
[~,order] = sort([scores,best_scores],'descend');
[~,order] = sort([scores,best_scores(c,:)],'descend');
order = order(1:opt.BestK);
id_keep = order( order > Nparticles ) - Nparticles;
id_new = order( order <= Nparticles );
best_scores = [ best_scores(id_keep), scores(id_new) ];
best_particles = [ best_particles(:,id_keep), particles(:,id_new) ];
best_scores(c+1,:) = [ best_scores(c,id_keep), scores(id_new) ];
best_particles = [ best_particles(:,id_keep), particles(:,id_new) ];
end
end
function particles = initialise( range, n )
function particles = uniform_sampling( range, n )
Vmin = range(:,1);
Vmax = range(:,2);
......@@ -118,24 +117,21 @@ function scores = evaluate( objective_fun, particles )
end
function particles = importance_sampling( scores, n )
function particles = importance_sampling( particles, scores, variance, n )
% cumulative distribution
gamma = cumsum(scores);
gamma = [0 gamma] / gamma(end);
particles = rand(1,n);
% pick particles
index = rand(1,n);
for i = 1:n
particles(i) = find( particles(i) > gamma, 1, 'last' );
index(i) = find( index(i) > gamma, 1, 'last' );
end
end
function particles = resample( seeds, variance, n )
[Ndims,Nseeds] = size(seeds);
particles = seeds(:,randi( Nseeds, 1, n ));
% add noise
Ndims = size(particles,1);
noise = bsxfun( @times, randn(Ndims,n), variance );
particles = bsxfun( @plus, particles, noise );
particles = bsxfun( @plus, particles(:,index), noise );
end
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment