clear all; close all; clc;

% probabilities of the two bandit arms
banditProbabilities(1) = 0.6;
banditProbabilities(2) = 0.55;

% set the number of games
numberOfGames = 1000;

% set the number of trials per game
numberOfTrials = 100;

% set the learning rate
learningRate = 0.2;

% set the exploration rate
explorationRate = 0.5;

% determine which arm is the best
[bestProbability, bestAction] = max(banditProbabilities);

% set the number of arms
numberOfArms = length(banditProbabilities);

% preallocate storage for choices and value estimates
myChoices = zeros(numberOfTrials, numberOfGames);
myValues = zeros(numberOfArms, numberOfTrials, numberOfGames);

for gameCounter = 1:numberOfGames

    % reset the value estimates at the start of each game
    valueEstimated = zeros(1, numberOfArms);

    for trialCounter = 1:numberOfTrials

        % choose an action (epsilon-greedy)
        if rand() < explorationRate
            % explore: choose one of the "n" arms at random
            currentAction = randi(numberOfArms);
        else
            % exploit: choose the greedy action
            if sum(valueEstimated) == 0
                currentAction = randi(numberOfArms);
            else
                [actionValue, currentAction] = max(valueEstimated);
            end
        end

        % record whether the best arm was chosen
        if currentAction == bestAction
            myChoices(trialCounter, gameCounter) = 1;
        else
            myChoices(trialCounter, gameCounter) = 0;
        end

        % decide if the trial was a win or a loss
        reward = 0;
        if rand() < banditProbabilities(currentAction)
            reward = 1;
        end

        % compute the prediction error
        predictionError = reward - valueEstimated(currentAction);

        % update the value estimate of the chosen arm
        valueEstimated(currentAction) = valueEstimated(currentAction) + ...
            predictionError * learningRate;

        % ensure the values do not explode (go above 1)
        if valueEstimated(currentAction) > 1
            valueEstimated(currentAction) = 1;
        end

        % store the estimated values for plotting
        myValues(:, trialCounter, gameCounter) = valueEstimated;

    end
end

% plot the mean estimated values and the choice accuracy across games
subplot(1,2,1);
finalValues = mean(myValues, 3);
finalValues = mean(finalValues, 2);
bar(finalValues);
title('Estimated Values');

subplot(1,2,2);
meanValues = mean(myChoices, 2);
bar(meanValues);
title('Accuracy');