-
Notifications
You must be signed in to change notification settings - Fork 1
/
predict.py
43 lines (26 loc) · 1.29 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import csv
import sys

from pyspark import SparkContext
from pyspark.mllib.recommendation import MatrixFactorizationModel
# Module-level SparkContext used by predict(). Because this runs at module
# level it is created as soon as the file is imported, not only when it is
# run as a script.
sc = SparkContext()
# Raise Spark's log threshold to ERROR so the printed recommendations are not
# buried in INFO/WARN console output.
sc.setLogLevel("ERROR")
def _parse_movie(line):
    """Split one movies.csv data row into ``(movieId, title)``.

    ``csv.reader`` is used instead of ``str.split(",")`` because MovieLens
    quotes fields that contain commas (e.g. ``"American President, The (1995)"``);
    a naive split would print only the fragment before the first comma.
    """
    fields = next(csv.reader([line]))
    return int(fields[0]), fields[1]


def predict(userID, top_n=15):
    """Print (and return) the titles of the best-scored unrated movies for a user.

    Loads the saved ALS model from ``target/model`` and the MovieLens files under
    ``datasets/ml-latest-small``. It scores every movie the user has not rated yet
    and prints the ``top_n`` titles with the highest predicted rating, best first.

    :param userID: user id as it appears in ratings.csv (str, e.g. from argv, or int).
    :param top_n: number of recommendations to print; defaults to 15.
    :return: list of the printed titles, best first. It is empty when the user
        has already rated every movie, or when the model produces no
        predictions for the user.
    :raises ValueError: if ``userID`` is not an integer.
    """
    # Normalize once so "7", 7 and "07" all match the integer ids parsed below.
    user_id = int(userID)

    movies = sc.textFile('datasets/ml-latest-small/movies.csv')
    movies_header = movies.first()
    # (movieId, title) pairs keyed by int, so they join directly with Rating.product.
    movie_titles = movies.filter(lambda line: line != movies_header).map(_parse_movie)

    model = MatrixFactorizationModel.load(sc, "target/model")

    ratings = sc.textFile('datasets/ml-latest-small/ratings.csv')
    ratings_header = ratings.first()
    # A set keeps the per-movie "already rated?" check O(1) instead of O(n).
    rated = set(
        ratings.filter(lambda line: line != ratings_header)
        .map(lambda line: line.split(","))
        .filter(lambda fields: int(fields[0]) == user_id)
        .map(lambda fields: int(fields[1]))
        .collect()
    )

    unrated_pairs = (
        movie_titles.keys()
        .filter(lambda movie_id: movie_id not in rated)
        .map(lambda movie_id: (user_id, movie_id))
    )
    # predictAll() calls first() on its input, which raises on an empty RDD.
    if unrated_pairs.isEmpty():
        return []

    scores = model.predictAll(unrated_pairs).map(lambda r: (r.product, r.rating))
    # Rank only AFTER the join: join() shuffles, so any sort done before it is lost.
    # Joined records are (movieId, (rating, title)).
    best = scores.join(movie_titles).takeOrdered(top_n, key=lambda kv: -kv[1][0])
    titles = [title for _, (_, title) in best]
    for title in titles:
        print(title)
    return titles
if __name__ == '__main__':
    # Expect the user id as the first command-line argument.
    if len(sys.argv) < 2:
        # Fail with a usage line rather than an IndexError traceback.
        sys.exit("usage: predict.py USER_ID")
    try:
        predict(sys.argv[1])
    finally:
        # Release the module-level SparkContext even if prediction fails.
        sc.stop()