Skip to content

Commit

Permalink
Make distribution sensitive to vowels/consonants (fix #62)
Browse files Browse the repository at this point in the history
  • Loading branch information
jcreedcmu committed Nov 20, 2023
1 parent 182e524 commit 53bda21
Show file tree
Hide file tree
Showing 3 changed files with 110 additions and 66 deletions.
66 changes: 49 additions & 17 deletions src/core/distribution.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,19 @@ import { produce } from "../util/produce";
import { next_rand } from "../util/util";
import { CoreState, GameState } from "./state";

// contains 26 values. The probability of a letter being picked next
// is proportional to e^{-βE}
export type Energies = number[];
/** Letter class tag: 0 marks a vowel, 1 marks any other letter. */
type LetterClass = 0 | 1;

/**
 * Classifies a zero-based alphabet position.
 *
 * @param index position in the alphabet, where 0 is 'a' and 25 is 'z'
 * @returns 0 when the position is one of a, e, i, o, u; 1 for every other value
 */
export function getClass(index: number): LetterClass {
  switch (index) {
    case 0: // a
    case 4: // e
    case 8: // i
    case 14: // o
    case 20: // u
      return 0;
    default:
      return 1;
  }
}

// Sampling state for the letter generator: one energy per letter plus one
// energy per letter class. For positive β, a lower energy means a higher
// chance of being drawn.
export type Energies =
{
// One energy per letter, indexed by alphabet position (26 values for the
// default alphabet). The probability of a letter being picked next
// is proportional to e^{-βE}
byLetter: number[],
// Two values, indexed by the result of getClass: [vowel, consonant].
byClass: number[],
}

// A list of probabilities that sum to 1. Contains 26 values when it describes
// the default alphabet; getLetterSampleOf also samples from a 2-entry
// per-class distribution.
export type Probs = number[];
Expand All @@ -16,13 +26,16 @@ const default_beta = 2;
// Proportionality constant for how much to adjust energies by: after a draw,
// getLetterSampleOf adds increment / weight to the energy of what was drawn.
const default_increment = 4;

export function distributionOf(energies: Energies, beta: number): Probs {
const unnormalizedProbs = energies.map(energy => Math.exp(-beta * energy));
/**
 * Converts a list of energies into a list of probabilities, where entry i is
 * proportional to e^{-beta * energies[i]}.
 *
 * @param energies one energy per outcome; an entry of Infinity yields
 *   probability 0 when beta is positive and at least one entry is finite
 * @param beta scale factor applied to every energy before exponentiation
 * @returns probabilities in the same order as `energies`, summing to 1;
 *   an empty array when `energies` is empty
 */
export function distributionOf(energies: number[], beta: number): number[] {
  // An empty input has no distribution. Returning early also avoids the
  // TypeError that `reduce` raises on an empty array without an initial value.
  if (energies.length === 0) {
    return [];
  }
  // Shift so the smallest energy is 0. The shift cancels out in the
  // normalization below but keeps the exponentials from overflowing or
  // all underflowing to 0 when the energies are large in magnitude.
  const minEnergy = Math.min(...energies);
  const unnormalizedProbs = energies.map(energy => Math.exp(-beta * (energy - minEnergy)));
  const sum = unnormalizedProbs.reduce((a, b) => a + b, 0);
  return unnormalizedProbs.map(prob => prob / sum);
}

export const letterDistribution: Record<string, number> = {
const letterDistribution: Record<string, number> = {
a: 10,
b: 3,
c: 4,
Expand Down Expand Up @@ -52,16 +65,28 @@ export const letterDistribution: Record<string, number> = {
};

// Letters of letterDistribution in sorted order; a letter's position in this
// array is the index used for the per-letter arrays below.
const alphabet = Object.keys(letterDistribution).sort();
// Weights from letterDistribution, aligned index-for-index with `alphabet`.
const letterDistributionNumbers = alphabet.map(letter => letterDistribution[letter]);

/**
 * Adds up the weights in `letterDistribution` per letter class.
 *
 * Keys are mapped to alphabet positions by subtracting 97 from their first
 * character code, so 'a' maps to 0.
 *
 * @returns a two-element array: total vowel weight at index 0, total weight
 *   of all other letters at index 1
 */
function mkClassDistribution(): number[] {
  const totals = [0, 0];
  for (const [letter, weight] of Object.entries(letterDistribution)) {
    const letterClass = getClass(letter.charCodeAt(0) - 97);
    totals[letterClass] += weight;
  }
  return totals;
}

// Per-class weights, computed once at module load.
const classDistribution = mkClassDistribution();

// Builds the starting energies from the default letter weights, the derived
// class weights, and default_beta.
export function initialEnergies(): Energies {
// NOTE(review): the next line passes the letterDistribution record rather than
// a number[]; it appears to be the pre-change line of this diff, superseded by
// the object literal returned below. Confirm against the commit.
return initialEnergiesOf(letterDistribution, default_beta);
return {
// Per-letter energies, in `alphabet` order.
byLetter: initialEnergiesOf(letterDistributionNumbers, default_beta),
// Per-class energies: [vowel, consonant].
byClass: initialEnergiesOf(classDistribution, default_beta),
};
}

// Converts relative weights into energies: each weight v becomes
// (1 / beta) * ln(1 / v), so that e^{-beta * energy} equals v and
// distributionOf(result, beta) is proportional to the weights.
// NOTE(review): this span shows two signatures and two return statements
// (Record-based and array-based); it appears to interleave the pre-change and
// post-change lines of a diff. Confirm against the commit.
export function initialEnergiesOf(letterDistribution: Record<string, number>, beta: number): Energies {
export function initialEnergiesOf(distribution: number[], beta: number): number[] {
// NOTE(review): `energies` is never read in the visible body.
const energies: number[] = [];
// Record-based variant: walks the keys in sorted order.
return Object.keys(letterDistribution).sort().map(letter =>
(1 / beta) * Math.log(1 / letterDistribution[letter])
);
// Array-based variant: output order matches input order.
return distribution.map(v => (1 / beta) * Math.log(1 / v));
}

export function getSample(seed0: number, probs: Probs): { sample: number, seed: number } {
Expand All @@ -88,13 +113,20 @@ export function getSample(seed0: number, probs: Probs): { sample: number, seed:
return { sample, seed };
}

export function getLetterSampleOf(seed0: number, energies0: Energies, letterDistribution: Record<string, number>, alphabet: string[], beta: number, increment: number): { seed: number, letter: string, energies: Energies } {
const { seed, sample } = getSample(seed0, distributionOf(energies0, beta));
const letter = alphabet[sample];
const energies = produce(energies0, e => { e[sample] += increment / letterDistribution[letter] });
return { seed, energies, letter };
/**
 * Draws one letter in two stages: first a letter class from
 * `energies0.byClass`, then a letter of that class from `energies0.byLetter`.
 * The energies of the drawn class and letter are then raised by
 * `increment / weight`.
 *
 * @param seed0 random seed to start from
 * @param energies0 current energies; a new value is produced via `produce`
 * @param letterDistribution weight of each letter, keyed by letter
 * @param classDistribution weight of each class, indexed by class
 * @param alphabet letters in the order used by `energies0.byLetter`
 * @param beta scale factor passed to `distributionOf`
 * @param increment numerator of the energy adjustment applied after the draw
 * @returns the seed after both draws, the drawn letter, and the updated energies
 */
export function getLetterSampleOf(seed0: number, energies0: Energies, letterDistribution: Record<string, number>, classDistribution: number[], alphabet: string[], beta: number, increment: number): { seed: number, letter: string, energies: Energies } {
  // Stage 1: choose the class.
  const classProbs = distributionOf(energies0.byClass, beta);
  const classDraw = getSample(seed0, classProbs);
  const chosenClass = classDraw.sample;

  // Stage 2: choose a letter, giving every letter outside the chosen class an
  // infinite energy so that only letters of that class can be drawn.
  const maskedEnergies = energies0.byLetter.map(
    (energy, ix) => (getClass(ix) === chosenClass ? energy : Infinity),
  );
  const letterDraw = getSample(classDraw.seed, distributionOf(maskedEnergies, beta));
  const letterIndex = letterDraw.sample;
  const letter = alphabet[letterIndex];

  // Raise the energies of what was just drawn; the smaller the weight, the
  // larger the adjustment.
  const energies = produce(energies0, draft => {
    draft.byClass[chosenClass] += increment / classDistribution[chosenClass];
    draft.byLetter[letterIndex] += increment / letterDistribution[letter];
  });

  return { seed: letterDraw.seed, energies, letter };
}

// Draws the next letter using the module's default letter weights, class
// weights, alphabet, beta and increment. Returns the advanced seed, the letter,
// and the updated energies.
export function getLetterSample(seed0: number, energies0: Energies): { seed: number, letter: string, energies: Energies } {
// NOTE(review): the next line omits classDistribution; it appears to be the
// pre-change line of this diff, superseded by the call below. Confirm against
// the commit.
return getLetterSampleOf(seed0, energies0, letterDistribution, alphabet, default_beta, default_increment);
return getLetterSampleOf(seed0, energies0, letterDistribution, classDistribution, alphabet, default_beta, default_increment);
}
59 changes: 31 additions & 28 deletions src/ui/instructions.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -82,34 +82,37 @@ function exampleState(): GameState {
invalidWords: [],
tile_entities: {},
connectedSet: mkGridOf([]),
energies: [
-0.20972339977922094,
0.7840271889992784,
1.3068528194400546,
0.3068528194400547,
0.5636791674230779,
1.6534264097200273,
0.3068528194400547,
0.7840271889992784,
0.6791654891096679,
4,
4,
1.1041202653859723,
0.7840271889992784,
0.4602792291600821,
0.4602792291600821,
0.7840271889992784,
4,
0.4602792291600821,
-0.0047189562170500965,
-0.039720770839917874,
0.79528104378295,
1.6534264097200273,
1.6534264097200273,
4,
1.6534264097200273,
0
],
energies: {
byLetter: [
-0.20972339977922094,
0.7840271889992784,
1.3068528194400546,
0.3068528194400547,
0.5636791674230779,
1.6534264097200273,
0.3068528194400547,
0.7840271889992784,
0.6791654891096679,
4,
4,
1.1041202653859723,
0.7840271889992784,
0.4602792291600821,
0.4602792291600821,
0.7840271889992784,
4,
0.4602792291600821,
-0.0047189562170500965,
-0.039720770839917874,
0.79528104378295,
1.6534264097200273,
1.6534264097200273,
4,
1.6534264097200273,
0
],
byClass: [0, 0,],
},
seed: 1533311107,
canvas_from_world: {
scale: {
Expand Down
51 changes: 30 additions & 21 deletions tests/test-distribution.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { distributionOf, getLetterSampleOf, getSample, initialEnergiesOf } from '../src/core/distribution';
import { Energies, distributionOf, getClass, getLetterSampleOf, getSample, initialEnergiesOf } from '../src/core/distribution';

describe('getSample', () => {

Expand All @@ -25,7 +25,7 @@ describe('getSample', () => {
// Round-trip check: energies derived from weights 1 and 4 with beta = 1 should
// map back to the normalized weights 1/5 and 4/5.
describe('currentDistributionOf', () => {

test(`should work as expected`, () => {
// NOTE(review): the next line passes a record to initialEnergiesOf; it appears
// to be the pre-change line of this diff, superseded by the array-based call
// below. Confirm against the commit.
expect(distributionOf(initialEnergiesOf({ 'a': 1, 'e': 4 }, 1), 1)).toEqual([0.2, 0.8]);
// NOTE(review): toEqual compares these floating-point results exactly;
// consider toBeCloseTo per element if this ever becomes flaky.
expect(distributionOf(initialEnergiesOf([1, 4], 1), 1)).toEqual([0.2, 0.8]);
});

});
Expand All @@ -35,30 +35,39 @@ describe('getLetterSample', () => {
test(`should work approximately as expected`, () => {

const samples = [];
let energies = [10, 0];
const letterDistribution = { a: 3, b: 1 };
const alphabet = ['a', 'b'];
let seed = 124;
let energies: Energies = { byLetter: [0, 0, 0, 0, 0], byClass: [1, 1] };
const letterDistribution = { A: 1, B: 3, C: 2, D: 1, E: 3 };
const classDistribution = [
1, // vowels,
2, // consonants,
];
const alphabet = ['A', 'B', 'C', 'D', 'E'];
const counts: Record<string, number> = {};
let seed = 121;
let sample = 0;
for (let i = 0; i < 40; i++) {
const b = getLetterSampleOf(seed, energies, letterDistribution, alphabet, 1, 1);
for (let i = 0; i < 1000; i++) {
const b = getLetterSampleOf(seed, energies, letterDistribution, classDistribution, alphabet, 1, 10);
seed = b.seed;
energies = b.energies;
samples.push(b.letter);
counts[b.letter] = (counts[b.letter] ?? 0) + 1;
}
expect(samples).toEqual([
// We heavily biased the energy vector in b's favor so we see a
// bunch of b's initially
'b', 'b', 'b', 'b', 'b', 'b', 'b', 'b', 'b', 'b',
// But eventually that bias runs out and we see roughly a 3-to-1
// ratio of a's to b's after we catch up
'a', 'b', 'a', 'a', 'a',
'a', 'b', 'a', 'a', 'b',
'a', 'a', 'b', 'a', 'a',
'a', 'b', 'b', 'a', 'a',
'a', 'b', 'a', 'a', 'b',
'a', 'a', 'a', 'a', 'a'
]);

expect(Math.abs(3 - counts.E / counts.A)).toBeLessThan(0.1);
expect(Math.abs(2 - (counts.B + counts.C + counts.D) / (counts.A + counts.E))).toBeLessThan(0.1);


});
});

// Checks that getClass tags the five vowels as class 0 and other letters as
// class 1. Indices are zero-based alphabet positions ('A' has char code 65).
describe('getClass', () => {
  // Registered with `test` rather than a nested `describe`, so the assertions
  // run as a reported test case instead of while the suite is being collected.
  test('should be correct', () => {
    expect(getClass('A'.charCodeAt(0) - 65)).toBe(0);
    expect(getClass('E'.charCodeAt(0) - 65)).toBe(0);
    expect(getClass('I'.charCodeAt(0) - 65)).toBe(0);
    expect(getClass('O'.charCodeAt(0) - 65)).toBe(0);
    expect(getClass('U'.charCodeAt(0) - 65)).toBe(0);
    // 'Y' is treated as a consonant.
    expect(getClass('Y'.charCodeAt(0) - 65)).toBe(1);
    expect(getClass('C'.charCodeAt(0) - 65)).toBe(1);
  });
});

0 comments on commit 53bda21

Please sign in to comment.