Commit 49c30e1e by laurencehunt

parent 61d4ddda
 ... ... @@ -31,6 +31,6 @@ nTrials = length(opt1Rewarded); %% STUDENTS - complete this code to finish the reinforcement learning model for t = 1:nTrials %loop over trials delta(t) = %%COMPLETE THIS LINE using opt1Rewarded, probOpt1 and equation 1 %%; % prediction error probOpt1(t+1) = %%COMPLETE THIS LINE using probOpt1, delta, alpha and equation 2 %%; % prediction for next trial delta(t) = opt1Rewarded(t) - probOpt1(t); %%COMPLETE THIS LINE using opt1Rewarded, probOpt1 and equation 1 %%; % prediction error probOpt1(t+1) = probOpt1(t) + alpha*delta(t); %%COMPLETE THIS LINE using probOpt1, delta, alpha and equation 2 %%; % prediction for next trial end \ No newline at end of file
 ... ... @@ -6,9 +6,12 @@ fixedProb = 0.8; % this is the true probability that green is rewarded startingProb = 0.5; % this defines the model's estimated pronbabilty on the very first trial alpha = 0.05; % this is the model's learning rate alpha = 0.1; % this is the model's learning rate nTrials = 200; % this is the number of trials trueProbability = ones(1,nTrials)*fixedProb; %reward probability on each trial %trueProbability = ones(1,nTrials)*fixedProb; %reward probability on each trial D = load('schedule.mat'); trueProbability = D.trueProbabilityStored; nTrials = length(trueProbability); % now, we'll simulate whether green was rewarded on every trial opt1Rewarded(1:nTrials) = rand(1,nTrials) < trueProbability; ... ... @@ -45,7 +48,7 @@ legend(plotIndex, {'True Probability' 'Trial outcomes' 'RL model probability'}); %% 3. How many trials do we look back into the past with different alphas? alpha = 0.15; alpha = 0.01; T = 25; for t = 1:T ... ...
 function MakeSimFitPlots load('dataFitted'); load('complexsimulatedDataFitted'); load('simulatedDataFitted'); load('shortsimDataFitted'); complexTruealphaS=complexsimulatedData.StableLearningrates; complexTruealphaV=complexsimulatedData.VolatileLearningrates; complexFitalphaS=complexsimulatedData.LearningRateStable; complexFitalphaV=complexsimulatedData.LearningRateVolatile; complexrS=corrcoef(complexFitalphaS,complexTruealphaS); complexrV=corrcoef(complexFitalphaV,complexTruealphaV); figure('color',[ 1 1 1],'name','Ground Truth Alphas and Simulation Fit');hold on; set(gca,'Fontsize',14); subplot(3,2,1);plot(complexTruealphaS,complexFitalphaS,'.');hold on; plot([0 1],[0 1],'k');xlim([0 1]);ylim([0 1]);title(['Complex Stable (r=' num2str(complexrS(1,end)) ')']);set(gca,'Fontsize',14); xlabel('True Alpha');ylabel('Simulated Alpha') subplot(3,2,2);plot(complexTruealphaV,complexFitalphaV,'.');hold on; plot([0 1],[0 1],'k');xlim([0 1]);ylim([0 1]);title(['Complex Volatile (r=' num2str(complexrV(1,end)) ')']);set(gca,'Fontsize',14); xlabel('True Alpha');ylabel('Simulated Alpha') simpleTruealphaS=simulatedData.Learningrates; simpleTruealphaV=simulatedData.Learningrates; simpleFitalphaS=simulatedData.LearningRateStable; simpleFitalphaV=simulatedData.LearningRateVolatile; simplerS=corrcoef(simpleFitalphaS,simpleTruealphaS); simplerV=corrcoef(simpleFitalphaV,simpleTruealphaV); subplot(3,2,3);plot(simpleTruealphaS,simpleFitalphaS,'.');xlim([0 1]);ylim([0 1]);title(['Simple Stable (r=' num2str(simplerS(1,end)) ')']);set(gca,'Fontsize',14);hold on; plot([0 1],[0 1],'k'); xlabel('True Alpha');ylabel('Simulated Alpha') subplot(3,2,4);plot(simpleTruealphaV,simpleFitalphaV,'.');xlim([0 1]);ylim([0 1]);title(['Simple Volatile (r=' num2str(simplerV(1,end)) ')']);set(gca,'Fontsize',14);hold on; plot([0 1],[0 1],'k'); xlabel('True Alpha');ylabel('Simulated Alpha') shortsimpleTruealphaS=shortsimData.Learningrates; shortsimpleTruealphaV=shortsimData.Learningrates; shortsimpleFitalphaS=shortsimData.LearningRateStable; shortsimpleFitalphaV=shortsimData.LearningRateVolatile; shortsimplerS=corrcoef(shortsimpleTruealphaS,shortsimpleFitalphaS); shortsimplerV=corrcoef(shortsimpleTruealphaV,shortsimpleFitalphaV); subplot(3,2,5);plot(shortsimpleTruealphaS,shortsimpleFitalphaS,'.');xlim([0 1]);ylim([0 1]);title(['Short Simple Stable (r=' num2str(shortsimplerS(1,end)) ')']); set(gca,'Fontsize',14);hold on; plot([0 1],[0 1],'k'); xlabel('True Alpha');ylabel('Simulated Alpha') subplot(3,2,6);plot(shortsimpleTruealphaV,shortsimpleFitalphaV,'.');xlim([0 1]);ylim([0 1]);title(['Short Simple Volatile (r=' num2str(shortsimplerV(1,end)) ')']);set(gca,'Fontsize',14);hold on; plot([0 1],[0 1],'k'); xlabel('True Alpha');ylabel('Simulated Alpha') xposition(:,1)=0.85:5.85;xposition(:,2)=1.15:6.15;Labels={'C Sim Alpha S';'C True Alpha S';'C Sim Alpha V';'C True Alpha V';... 'S Sim Alpha S';'S True Alpha S';'S Sim Alpha V';'S True Alpha V';'ShortS Sim Alpha S';'ShortS True Alpha S';'ShortS Sim Alpha V';'ShortS True Alpha V'}; figure('color',[ 1 1 1],'name','Simulated and Ground Truth Alphas');hold on; set(gca,'Fontsize',10);ylim([0 .45]);set(gca,'XtickLabel',Labels,'XTick',sort(xposition(:),'ascend')) try set(gca,'XTickLabelRotation',45,'Fontsize',16) catch end bar([mean(complexFitalphaS) mean(complexFitalphaV) mean(simpleFitalphaS) mean(simpleFitalphaV) mean(shortsimpleFitalphaS) mean(shortsimpleFitalphaV);... mean(complexTruealphaS) mean(complexTruealphaV) mean(simpleTruealphaS) mean(simpleTruealphaV) mean(shortsimpleTruealphaS) mean(shortsimpleTruealphaV)]');hold on; errorbar(xposition,[mean(complexFitalphaS) mean(complexFitalphaV) mean(simpleFitalphaS) mean(simpleFitalphaV) mean(shortsimpleFitalphaS) mean(shortsimpleFitalphaV);... mean(complexTruealphaS) mean(complexTruealphaV) mean(simpleTruealphaS) mean(simpleTruealphaV) mean(shortsimpleTruealphaS) mean(shortsimpleTruealphaV)]',[... std(complexFitalphaS) std(complexFitalphaV) std(simpleFitalphaS) std(simpleFitalphaV) std(shortsimpleFitalphaS) std(shortsimpleFitalphaV);... std(complexTruealphaS) std(complexTruealphaV) std(simpleTruealphaS) std(simpleTruealphaV) std(shortsimpleTruealphaS) std(shortsimpleTruealphaV)]'./(length(shortsimpleFitalphaV).^.5),'.k' ) ylabel('Learning rate (alpha)') figure('color',[ 1 1 1],'name','Simulation Fits');hold on; set(gca,'Fontsize',14); subplot(3,2,1);hist(complexFitalphaS,100);hold on;ylim([0 20]) title(['Complex Stable (r=' num2str(complexrS(1,end)) ')']);set(gca,'Fontsize',14); xlabel('Simulated Alpha'); subplot(3,2,2);hist(complexFitalphaV,100);hold on;ylim([0 20]) title(['Complex Volatile (r=' num2str(complexrV(1,end)) ')']);set(gca,'Fontsize',14); xlabel('Simulated Alpha'); subplot(3,2,3);hist(simpleFitalphaS,100);title(['Simple Stable (r=' num2str(simplerS(1,end)) ')']);set(gca,'Fontsize',14);hold on;ylim([0 20]) xlabel('Simulated Alpha'); subplot(3,2,4);hist(simpleFitalphaV,100);title(['Simple Volatile (r=' num2str(simplerV(1,end)) ')']);set(gca,'Fontsize',14);hold on;ylim([0 20]) xlabel('Simulated Alpha'); subplot(3,2,5);hist(shortsimpleFitalphaS,100);title(['Short Simple Stable (r=' num2str(shortsimplerS(1,end)) ')']); set(gca,'Fontsize',14);hold on;ylim([0 20]) xlabel('Simulated Alpha'); subplot(3,2,6);hist(shortsimpleFitalphaV,100);title(['Short Simple Volatile (r=' num2str(shortsimplerV(1,end)) ')']);set(gca,'Fontsize',14);hold on;ylim([0 20]) xlabel('Simulated Alpha'); figure('color',[ 1 1 1],'name','Ground Truth Alphas');hold on; set(gca,'Fontsize',14); subplot(3,2,1);hist(complexTruealphaS,100);hold on;ylim([0 20]) title(['Complex Stable (r=' num2str(complexrS(1,end)) ')']);set(gca,'Fontsize',14); xlabel('Simulated Alpha'); subplot(3,2,2);hist(complexTruealphaV,100);hold on;ylim([0 20]) title(['Complex Volatile (r=' num2str(complexrV(1,end)) ')']);set(gca,'Fontsize',14); xlabel('Simulated Alpha'); subplot(3,2,3);hist(simpleTruealphaS,100);title(['Simple Stable (r=' num2str(simplerS(1,end)) ')']);set(gca,'Fontsize',14);hold on;ylim([0 20]) xlabel('Simulated Alpha'); subplot(3,2,4);hist(simpleTruealphaV,100);title(['Simple Volatile (r=' num2str(simplerV(1,end)) ')']);set(gca,'Fontsize',14);hold on;ylim([0 20]) xlabel('Simulated Alpha'); subplot(3,2,5);hist(shortsimpleTruealphaS,100);title(['Short Simple Stable (r=' num2str(shortsimplerS(1,end)) ')']); set(gca,'Fontsize',14);hold on;ylim([0 20]) xlabel('Simulated Alpha'); subplot(3,2,6);hist(shortsimpleTruealphaV,100);title(['Short Simple Volatile (r=' num2str(shortsimplerV(1,end)) ')']);set(gca,'Fontsize',14);hold on;ylim([0 20]) xlabel('Simulated Alpha');