-
Notifications
You must be signed in to change notification settings - Fork 6
/
utils.py
62 lines (38 loc) · 1.14 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
def get_parameters(model):
    """Return the hyper-parameter dictionary for a named model.

    Args:
        model: Model name; the recognised values are ``'MNIST'`` and
            ``'Gauss2D'``.

    Returns:
        A newly built ``dict`` mapping parameter names to values. A fresh
        dictionary is created on every call, so callers may mutate it.

    Raises:
        NameError: If ``model`` is not one of the recognised names.
    """
    if model == 'MNIST':
        params = {
            'model': 'MNIST',
            'alpha': .7,
            'Nmin': 5,
            'Nmax': 100,
            'h_dim': 256,
            'g_dim': 512,
            'H_dim': 128,
        }
    elif model == 'Gauss2D':
        params = {
            'model': 'Gauss2D',
            'alpha': .7,
            'sigma': 1,  # std for the Gaussian noise around the cluster mean
            'lambda': 10,  # std for the Gaussian prior that generates the centers of the clusters
            'Nmin': 5,
            'Nmax': 100,
            'x_dim': 2,
            'h_dim': 256,
            'g_dim': 512,
            'H_dim': 128,
        }
    else:
        # str() so a non-string model (e.g. None) still raises NameError
        # rather than a TypeError from str + non-str concatenation.
        raise NameError('Unknown model ' + str(model))
    return params
def relabel(cs):
    """Renumber labels to consecutive integers in order of first appearance.

    The first distinct value seen in ``cs`` becomes 0, the next new value
    becomes 1, and so on; repeated values keep the number assigned at their
    first occurrence.

    Args:
        cs: Sequence of hashable labels that supports ``.copy()`` and integer
            item assignment (e.g. a ``list``; presumably also a 1-D NumPy
            array — confirm against callers).

    Returns:
        A relabelled copy of ``cs``; the input itself is not modified.
    """
    out = cs.copy()
    mapping = {}
    # Read from the untouched input while writing into the copy.
    for i, label in enumerate(cs):
        # len(mapping) is the next unused label at the moment a value is new.
        out[i] = mapping.setdefault(label, len(mapping))
    return out