#!/usr/bin/python

import glob
import math
import random

# Loads the stop words, one per line, from stop_words.txt.
# A dict with dummy True values is used as a set; it stays a dict because the
# rest of the script only ever tests keys for membership.
stop_word = {}
with open("stop_words.txt", "r") as stop_word_file:
    for line in stop_word_file:
        # rstrip() removes the newline and any other trailing whitespace
        stop_word[line.rstrip()] = True


# Every post is a file located at 20_newsgroups/<group>/<post>
post_files = glob.glob("20_newsgroups/*/*")

# Shuffles so that both sets contain a random mix of all groups
random.shuffle(post_files)

# The first TRAINING_SET_SIZE posts train the classifier; every remaining post
# is used for testing.  The test slice starts exactly where the training
# slice ends, so no post is left out of both sets.
TRAINING_SET_SIZE = 18000
training_set = post_files[:TRAINING_SET_SIZE]
test_set = post_files[TRAINING_SET_SIZE:]

##################### Calculates P(H)

# posts[group]   -> flat list of every token found in the group's training posts
# p_group[group] -> number of training posts in the group; turned into the
#                   prior P(H) by the normalisation loop at the end
posts = {}
p_group = {}

for post_file in training_set:
    # The group name is the directory that contains the post
    group = post_file.split("/")[-2]

    # Counts the number of posts in each group
    p_group[group] = p_group.get(group, 0.0) + 1.0

    # Reads the post; the with-statement closes the file even if read() fails
    with open(post_file, "r") as post:
        text = post.read()

    # Tokenizes: newlines become spaces, periods are dropped, and the text is
    # split on single spaces (so consecutive spaces yield empty tokens).
    # This must stay identical to the tokenization used when classifying.
    tokens = text.replace("\n", " ").replace(".", "").split(" ")

    # Adds words from the post to group
    posts.setdefault(group, []).extend(tokens)

# Calculates probabilities: share of the training posts that is in each group
training_set_size = len(training_set)
for group in p_group:
    p_group[group] /= training_set_size

##################### Removes irrelevant words

# Counts the number of times each word occurs in all the posts combined
word_count = {}
for group in posts:
    for word in posts[group]:
        word_count[word] = word_count.get(word, 0) + 1

# Includes only words that occur more than 2 times and that are not stop
# words.  The vocabulary is a dict with dummy True values used as a set.
vocabulary = {}
for word in word_count:
    if word_count[word] > 2 and word not in stop_word:
        vocabulary[word] = True

##################### Calculates P(O | H)

# p_word_given_group[group][word] is the add-one (Laplace) smoothed estimate
#
#     (occurrences of word in group + 1) / (vocabulary tokens in group + |V|)
#
# Only tokens that belong to the vocabulary are counted in the denominator.
# Dividing by the total number of tokens instead (including stop words, rare
# words and empty tokens) would make each group's probabilities sum to less
# than 1 by a different amount per group, which skews the comparison between
# groups when classifying.
p_word_given_group = {}
vocabulary_size = len(vocabulary)

for group in posts:
    # Every vocabulary word starts at 1.0 so that unseen words never get
    # probability zero
    word_probability = dict.fromkeys(vocabulary, 1.0)

    # Counts the occurrences of each vocabulary word in the group
    vocabulary_tokens = 0
    for word in posts[group]:
        if word in word_probability:
            word_probability[word] += 1.0
            vocabulary_tokens += 1

    # Calculates probabilities
    denominator = float(vocabulary_tokens + vocabulary_size)
    for word in word_probability:
        word_probability[word] /= denominator

    p_word_given_group[group] = word_probability
        
##################### Classify posts

# errors / total are kept as floats so that the accuracy is a true division
errors = 0.0
total = 0.0

for post_file in test_set:
    # The true group is the directory that contains the post
    true_group = post_file.split("/")[-2]

    # Reads and tokenizes the post exactly like the training posts were
    with open(post_file, "r") as post:
        text = post.read()
    post_to_be_classified = text.replace("\n", " ").replace(".", "").split(" ")

    # Words outside the vocabulary are ignored; filtering once here avoids
    # repeating the membership test for every candidate group
    known_words = [word for word in post_to_be_classified if word in vocabulary]

    # Finds group with max P(O | H) * P(H).  Logarithms are summed instead of
    # multiplying probabilities to avoid floating point underflow.
    max_group = None
    max_p = None
    for candidate_group in posts:
        word_probability = p_word_given_group[candidate_group]

        # Calculates log(P(O | H) * P(H)) for candidate group
        p = math.log(p_group[candidate_group])
        for word in known_words:
            p += math.log(word_probability[word])

        # The first candidate is always accepted; later ones must be strictly
        # better, so ties keep the earlier group
        if max_p is None or p > max_p:
            max_p = p
            max_group = candidate_group

    total += 1.0
    if true_group != max_group:
        errors += 1.0

    # Reports the running accuracy after every 100 classified posts
    if total % 100 == 0:
        print("%.3f" % (1.0 - errors / total))

# Reports the final accuracy when the last batch was not a full 100 posts
if total > 0 and total % 100 != 0:
    print("%.3f" % (1.0 - errors / total))
