Commit 49c30e1e authored by laurencehunt

add session 2

parent 61d4ddda
@@ -4,4 +4,6 @@
In the first tutorial (developed by Laurence), we introduce a simple reinforcement learning model for a task based on the paradigm of Behrens et al., 2007, with an additional (between-subjects) stress manipulation. The paradigm runs in Psychtoolbox, which must be installed for it to run properly: see http://psychtoolbox.org. The stress manipulation can be switched off by setting trialvariables.playSound to 0 for all trials (or simply by muting the sound).
In the second tutorial (developed by Nils), students learn about the basics of reinforcement learning, and the importance of simulating data from a model and parameter recovery from simulated data.
Please let the tutorial authors know if you encounter any unintentional errors. Please also feel free to contact us if you have any questions, although we may not be able to respond to all enquiries.
\ No newline at end of file
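Review note: the README mentions simulating data from a model and recovering parameters from the simulated data. A minimal sketch of that idea follows. It is not the tutorial's code: the probability-matching choice rule, the grid search, and the names trueAlpha, alphaGrid, negLL and alphaHat are all assumptions made for illustration.

```matlab
% Hypothetical parameter-recovery sketch (not the tutorial's implementation).
rng(1);                                   % fixed seed so the run is repeatable
nTrials   = 500;
trueProb  = 0.8 * ones(1, nTrials);       % stable reward probability
trueAlpha = 0.1;                          % ground-truth learning rate
outcome   = rand(1, nTrials) < trueProb;  % 1 if option 1 was rewarded

% 1) Simulate an agent that chooses option 1 with probability p
%    (probability matching is an assumption of this sketch).
p = 0.5;
choice = false(1, nTrials);
for t = 1:nTrials
    choice(t) = rand < p;
    p = p + trueAlpha * (outcome(t) - p); % same delta-rule update as the model
end

% 2) Recover alpha by grid search on the negative log-likelihood of the choices.
alphaGrid = 0.01:0.01:1;
negLL = zeros(size(alphaGrid));
for i = 1:numel(alphaGrid)
    p = 0.5;
    for t = 1:nTrials
        pChoice  = choice(t) * p + (1 - choice(t)) * (1 - p);
        negLL(i) = negLL(i) - log(max(pChoice, eps)); % guard against log(0)
        p = p + alphaGrid(i) * (outcome(t) - p);
    end
end
[~, best] = min(negLL);
alphaHat  = alphaGrid(best);              % recovered learning rate
```

Comparing alphaHat with trueAlpha over many simulated agents is what a recovery plot summarises. With a stable schedule like this one the learning rate may be only weakly constrained, which is one reason to check recovery before fitting real data.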
@@ -31,6 +31,6 @@ nTrials = length(opt1Rewarded);
%% STUDENTS - complete this code to finish the reinforcement learning model
for t = 1:nTrials %loop over trials
-delta(t) = %%COMPLETE THIS LINE using opt1Rewarded, probOpt1 and equation 1 %%; % prediction error
-probOpt1(t+1) = %%COMPLETE THIS LINE using probOpt1, delta, alpha and equation 2 %%; % prediction for next trial
+delta(t) = opt1Rewarded(t) - probOpt1(t); %%COMPLETE THIS LINE using opt1Rewarded, probOpt1 and equation 1 %%; % prediction error
+probOpt1(t+1) = probOpt1(t) + alpha*delta(t); %%COMPLETE THIS LINE using probOpt1, delta, alpha and equation 2 %%; % prediction for next trial
end
\ No newline at end of file
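Review note: a quick sanity check on the completed update, offered as a sketch. It assumes every trial is rewarded (a hypothetical input, not the tutorial's schedule), in which case the estimate should follow the closed form 1 - (1 - p0)(1 - alpha)^n. The names p0, closedForm and maxErr are illustrative.

```matlab
% Hypothetical check of the delta-rule update against its closed form.
alpha   = 0.1;
p0      = 0.5;                         % starting estimate
nTrials = 50;
opt1Rewarded = ones(1, nTrials);       % assumption: option 1 always rewarded

probOpt1    = zeros(1, nTrials + 1);   % preallocate, one extra for t+1
probOpt1(1) = p0;
delta       = zeros(1, nTrials);
for t = 1:nTrials
    delta(t)      = opt1Rewarded(t) - probOpt1(t);    % prediction error
    probOpt1(t+1) = probOpt1(t) + alpha * delta(t);   % prediction for next trial
end

% With constant reward, 1 - p shrinks by a factor (1 - alpha) on every trial.
closedForm = 1 - (1 - p0) * (1 - alpha).^(0:nTrials);
maxErr     = max(abs(probOpt1 - closedForm));
```

If maxErr is not at floating-point precision, the two completed lines are likely wrong (for example, a sign error in the prediction error).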
@@ -6,9 +6,12 @@
fixedProb = 0.8; % this is the true probability that green is rewarded
startingProb = 0.5; % this defines the model's estimated probability on the very first trial
-alpha = 0.05; % this is the model's learning rate
+alpha = 0.1; % this is the model's learning rate
nTrials = 200; % this is the number of trials
-trueProbability = ones(1,nTrials)*fixedProb; %reward probability on each trial
+%trueProbability = ones(1,nTrials)*fixedProb; %reward probability on each trial
+D = load('schedule.mat');
+trueProbability = D.trueProbabilityStored;
+nTrials = length(trueProbability);
% now, we'll simulate whether green was rewarded on every trial
opt1Rewarded(1:nTrials) = rand(1,nTrials) < trueProbability;
@@ -45,7 +48,7 @@ legend(plotIndex, {'True Probability' 'Trial outcomes' 'RL model probability'});
%% 3. How many trials do we look back into the past with different alphas?
-alpha = 0.15;
+alpha = 0.01;
T = 25;
for t = 1:T
......
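Review note: the rest of section 3 is elided in this view, so the following is only a sketch of the idea behind the question, not the tutorial's code. Unrolling the update shows that an outcome k trials in the past carries weight alpha*(1-alpha)^k in the current estimate. The names alphas, lags, weights and halfLife are illustrative.

```matlab
% Hypothetical sketch: how far back does the delta rule "look"?
T      = 25;                         % number of past trials to inspect
alphas = [0.01 0.15];                % the two learning rates in this hunk
lags   = 0:T-1;                      % 0 = most recent outcome

weights = zeros(numel(alphas), T);
for i = 1:numel(alphas)
    % weight of the outcome seen lags(k) trials ago
    weights(i, :) = alphas(i) * (1 - alphas(i)).^lags;
end

% Number of trials after which an outcome's weight has halved:
% roughly 69 trials for alpha = 0.01 and roughly 4 trials for alpha = 0.15.
halfLife = log(0.5) ./ log(1 - alphas);
```

A small alpha spreads weight thinly over a long history; a large alpha concentrates it on the last few trials. Plotting weights against lags makes the difference visible.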
function MakeSimFitPlots
load('dataFitted');
load('complexsimulatedDataFitted');
load('simulatedDataFitted');
load('shortsimDataFitted');
complexTruealphaS=complexsimulatedData.StableLearningrates;
complexTruealphaV=complexsimulatedData.VolatileLearningrates;
complexFitalphaS=complexsimulatedData.LearningRateStable;
complexFitalphaV=complexsimulatedData.LearningRateVolatile;
complexrS=corrcoef(complexFitalphaS,complexTruealphaS);
complexrV=corrcoef(complexFitalphaV,complexTruealphaV);
figure('color',[ 1 1 1],'name','Ground Truth Alphas and Simulation Fit');hold on; set(gca,'Fontsize',14);
subplot(3,2,1);plot(complexTruealphaS,complexFitalphaS,'.');hold on;
plot([0 1],[0 1],'k');xlim([0 1]);ylim([0 1]);title(['Complex Stable (r=' num2str(complexrS(1,end)) ')']);set(gca,'Fontsize',14);
xlabel('True Alpha');ylabel('Simulated Alpha')
subplot(3,2,2);plot(complexTruealphaV,complexFitalphaV,'.');hold on;
plot([0 1],[0 1],'k');xlim([0 1]);ylim([0 1]);title(['Complex Volatile (r=' num2str(complexrV(1,end)) ')']);set(gca,'Fontsize',14);
xlabel('True Alpha');ylabel('Simulated Alpha')
simpleTruealphaS=simulatedData.Learningrates;
simpleTruealphaV=simulatedData.Learningrates;
simpleFitalphaS=simulatedData.LearningRateStable;
simpleFitalphaV=simulatedData.LearningRateVolatile;
simplerS=corrcoef(simpleFitalphaS,simpleTruealphaS);
simplerV=corrcoef(simpleFitalphaV,simpleTruealphaV);
subplot(3,2,3);plot(simpleTruealphaS,simpleFitalphaS,'.');xlim([0 1]);ylim([0 1]);title(['Simple Stable (r=' num2str(simplerS(1,end)) ')']);set(gca,'Fontsize',14);hold on;
plot([0 1],[0 1],'k');
xlabel('True Alpha');ylabel('Simulated Alpha')
subplot(3,2,4);plot(simpleTruealphaV,simpleFitalphaV,'.');xlim([0 1]);ylim([0 1]);title(['Simple Volatile (r=' num2str(simplerV(1,end)) ')']);set(gca,'Fontsize',14);hold on;
plot([0 1],[0 1],'k');
xlabel('True Alpha');ylabel('Simulated Alpha')
shortsimpleTruealphaS=shortsimData.Learningrates;
shortsimpleTruealphaV=shortsimData.Learningrates;
shortsimpleFitalphaS=shortsimData.LearningRateStable;
shortsimpleFitalphaV=shortsimData.LearningRateVolatile;
shortsimplerS=corrcoef(shortsimpleTruealphaS,shortsimpleFitalphaS);
shortsimplerV=corrcoef(shortsimpleTruealphaV,shortsimpleFitalphaV);
subplot(3,2,5);plot(shortsimpleTruealphaS,shortsimpleFitalphaS,'.');xlim([0 1]);ylim([0 1]);title(['Short Simple Stable (r=' num2str(shortsimplerS(1,end)) ')']); set(gca,'Fontsize',14);hold on;
plot([0 1],[0 1],'k');
xlabel('True Alpha');ylabel('Simulated Alpha')
subplot(3,2,6);plot(shortsimpleTruealphaV,shortsimpleFitalphaV,'.');xlim([0 1]);ylim([0 1]);title(['Short Simple Volatile (r=' num2str(shortsimplerV(1,end)) ')']);set(gca,'Fontsize',14);hold on;
plot([0 1],[0 1],'k');
xlabel('True Alpha');ylabel('Simulated Alpha')
xposition(:,1)=0.85:5.85;xposition(:,2)=1.15:6.15;Labels={'C Sim Alpha S';'C True Alpha S';'C Sim Alpha V';'C True Alpha V';...
'S Sim Alpha S';'S True Alpha S';'S Sim Alpha V';'S True Alpha V';'ShortS Sim Alpha S';'ShortS True Alpha S';'ShortS Sim Alpha V';'ShortS True Alpha V'};
figure('color',[ 1 1 1],'name','Simulated and Ground Truth Alphas');hold on; set(gca,'Fontsize',10);ylim([0 .45]);set(gca,'XtickLabel',Labels,'XTick',sort(xposition(:),'ascend'))
try
set(gca,'XTickLabelRotation',45,'Fontsize',16)
catch
end
bar([mean(complexFitalphaS) mean(complexFitalphaV) mean(simpleFitalphaS) mean(simpleFitalphaV) mean(shortsimpleFitalphaS) mean(shortsimpleFitalphaV);...
mean(complexTruealphaS) mean(complexTruealphaV) mean(simpleTruealphaS) mean(simpleTruealphaV) mean(shortsimpleTruealphaS) mean(shortsimpleTruealphaV)]');hold on;
errorbar(xposition,[mean(complexFitalphaS) mean(complexFitalphaV) mean(simpleFitalphaS) mean(simpleFitalphaV) mean(shortsimpleFitalphaS) mean(shortsimpleFitalphaV);...
mean(complexTruealphaS) mean(complexTruealphaV) mean(simpleTruealphaS) mean(simpleTruealphaV) mean(shortsimpleTruealphaS) mean(shortsimpleTruealphaV)]',[...
std(complexFitalphaS) std(complexFitalphaV) std(simpleFitalphaS) std(simpleFitalphaV) std(shortsimpleFitalphaS) std(shortsimpleFitalphaV);...
std(complexTruealphaS) std(complexTruealphaV) std(simpleTruealphaS) std(simpleTruealphaV) std(shortsimpleTruealphaS) std(shortsimpleTruealphaV)]'./(length(shortsimpleFitalphaV).^.5),'.k' )
ylabel('Learning rate (alpha)')
figure('color',[ 1 1 1],'name','Simulation Fits');hold on; set(gca,'Fontsize',14);
subplot(3,2,1);hist(complexFitalphaS,100);hold on;ylim([0 20])
title(['Complex Stable (r=' num2str(complexrS(1,end)) ')']);set(gca,'Fontsize',14);
xlabel('Simulated Alpha');
subplot(3,2,2);hist(complexFitalphaV,100);hold on;ylim([0 20])
title(['Complex Volatile (r=' num2str(complexrV(1,end)) ')']);set(gca,'Fontsize',14);
xlabel('Simulated Alpha');
subplot(3,2,3);hist(simpleFitalphaS,100);title(['Simple Stable (r=' num2str(simplerS(1,end)) ')']);set(gca,'Fontsize',14);hold on;ylim([0 20])
xlabel('Simulated Alpha');
subplot(3,2,4);hist(simpleFitalphaV,100);title(['Simple Volatile (r=' num2str(simplerV(1,end)) ')']);set(gca,'Fontsize',14);hold on;ylim([0 20])
xlabel('Simulated Alpha');
subplot(3,2,5);hist(shortsimpleFitalphaS,100);title(['Short Simple Stable (r=' num2str(shortsimplerS(1,end)) ')']); set(gca,'Fontsize',14);hold on;ylim([0 20])
xlabel('Simulated Alpha');
subplot(3,2,6);hist(shortsimpleFitalphaV,100);title(['Short Simple Volatile (r=' num2str(shortsimplerV(1,end)) ')']);set(gca,'Fontsize',14);hold on;ylim([0 20])
xlabel('Simulated Alpha');
figure('color',[ 1 1 1],'name','Ground Truth Alphas');hold on; set(gca,'Fontsize',14);
subplot(3,2,1);hist(complexTruealphaS,100);hold on;ylim([0 20])
title(['Complex Stable (r=' num2str(complexrS(1,end)) ')']);set(gca,'Fontsize',14);
xlabel('True Alpha');
subplot(3,2,2);hist(complexTruealphaV,100);hold on;ylim([0 20])
title(['Complex Volatile (r=' num2str(complexrV(1,end)) ')']);set(gca,'Fontsize',14);
xlabel('True Alpha');
subplot(3,2,3);hist(simpleTruealphaS,100);title(['Simple Stable (r=' num2str(simplerS(1,end)) ')']);set(gca,'Fontsize',14);hold on;ylim([0 20])
xlabel('True Alpha');
subplot(3,2,4);hist(simpleTruealphaV,100);title(['Simple Volatile (r=' num2str(simplerV(1,end)) ')']);set(gca,'Fontsize',14);hold on;ylim([0 20])
xlabel('True Alpha');
subplot(3,2,5);hist(shortsimpleTruealphaS,100);title(['Short Simple Stable (r=' num2str(shortsimplerS(1,end)) ')']); set(gca,'Fontsize',14);hold on;ylim([0 20])
xlabel('True Alpha');
subplot(3,2,6);hist(shortsimpleTruealphaV,100);title(['Short Simple Volatile (r=' num2str(shortsimplerV(1,end)) ')']);set(gca,'Fontsize',14);hold on;ylim([0 20])
xlabel('True Alpha');
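Review note: the errorbar call in MakeSimFitPlots divides every standard deviation by the square root of length(shortsimpleFitalphaV). That is only correct if all six data sets contain the same number of simulated subjects, which cannot be confirmed from this diff. A minimal sketch of a per-vector standard error follows; the helper name sem and the example vectors are hypothetical.

```matlab
% Hypothetical helper: standard error of the mean using each vector's own length.
sem = @(x) std(x) / sqrt(numel(x));

% Example with made-up vectors of different lengths.
groupA = [1 2 3 4];
groupB = [0.1 0.2 0.3 0.4 0.5 0.6];
semA = sem(groupA);
semB = sem(groupB);
```

Using such a helper per group, for example sem(complexFitalphaS), would keep the error bars correct even if the short and long simulations differ in sample size.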