/*
 *	Knuth-Morris-Pratt search automaton for N given strings
 *		(taken from ad2 super-server and pruned :-)
 * 
 *	(c) 1999, 2001, Robert Spalek <robert@ucw.cz>
 */

#include "sherlock/sherlock.h"
#include "lib/mempool.h"
#include "lib/lists.h"
#include "sherlock/tagged-text.h"
#include "lib/unicode.h"
#include "charset/unicat.h"
#include "lang/kmp.h"

#include <stdlib.h>
#include <stdio.h>
#include <string.h>
#include <alloca.h>

/*
 * Debug tracing hook.  The call is compiled (so format strings and
 * arguments keep being type-checked) but never executed.  The `level'
 * argument is accepted and ignored.
 *
 * The body is wrapped in do { } while (0) so that the macro behaves as a
 * single statement: a bare `if (0) ...' expansion would swallow an `else'
 * that follows `if (cond) TRACE(...);'.
 */
#define	TRACE(level, ...)	do { if (0) fprintf(stderr, __VA_ARGS__); } while (0)

/*
 * Allocate and initialise an empty automaton from the memory pool `mp'.
 *
 * words_len    -- number of entries reserved in the per-state arrays
 *                 (sons, f, out); also stored in the structure
 * modify_flags -- MF_* character translation flags, stored for later use
 *
 * The automaton starts with a single state (state 0).  The transition hash
 * table gets words_len * fls(words_len) slots, or one slot when
 * words_len <= 1.
 *
 * NOTE(review): the son list of state 0 is initialised unconditionally, so
 * words_len is presumably expected to be at least 1 -- confirm with callers.
 */
struct kmp *
kmp_new(struct mempool *mp, int words_len, uns modify_flags)
{
	struct kmp *automaton = mp_alloc_zero(mp, sizeof(struct kmp));
	int hash_slots;

	automaton->mp = mp;
	automaton->modify_flags = modify_flags;
	automaton->words_len = words_len;

	/* Goto function: state 0 exists from the start and has no sons yet. */
	automaton->g.count = 1;
	automaton->g.size = words_len;
	automaton->g.sons = mp_alloc_zero(mp, words_len * sizeof(struct list));
	init_list(automaton->g.sons + 0);

	/* Hash table of all transitions. */
	hash_slots = (words_len > 1) ? words_len * fls(words_len) : 1;
	automaton->g.hash_size = hash_slots;
	automaton->g.chain = mp_alloc_zero(mp, hash_slots * sizeof(struct kmp_transition *));

	/* Per-state fallback targets and output lists. */
	automaton->f = mp_alloc_zero(mp, words_len * sizeof(kmp_state_t));
	automaton->out = mp_alloc_zero(mp, words_len * sizeof(struct kmp_output *));
	return automaton;
}

/*
 * Hash a transition by its (from, c) pair; `to' and the table `l' are not
 * used.  The caller reduces the result modulo the table size.
 *
 * The character is converted to uns before shifting so that the shift is
 * always done in unsigned arithmetic: if kmp_char_t is narrower than int,
 * the operand would otherwise be promoted to signed int and shifting a
 * value >= 0x8000 left by 16 would overflow.
 */
static inline uns
transition_hashf(struct kmp_transitions *l UNUSED, struct kmp_transition *tr)
{
	return tr->from + ((uns) tr->c << 16);
}

/*
 * Compare two transitions by their key (source state and character).
 * Returns 0 when both fields are equal and 1 otherwise; there is no
 * ordering, only equality.
 */
static inline int
transition_compare(struct kmp_transition *a, struct kmp_transition *b)
{
	return a->from != b->from || a->c != b->c;
}

/*
 * Look up the transition with the same (from, c) key as `tr' in the hash
 * table `l'.
 *
 * Returns the address of the link that points to the matching transition.
 * When no such transition exists, the returned link is the terminating
 * NULL link of the collision chain, i.e. the place where a new transition
 * can be hooked in.  The returned pointer itself is never NULL.
 */
static inline struct kmp_transition **
transition_search(struct kmp_transitions *l, struct kmp_transition *tr)
{
	struct kmp_transition **link;

	for (link = &l->chain[transition_hashf(l, tr) % l->hash_size]; *link; link = &(*link)->next)
		if (!transition_compare(*link, tr))
			break;
	ASSERT(link);
	return link;
}

/*
 * Append the output list `src' to the end of the list whose head pointer
 * is `*target'.  Nothing is copied: the last element of the target list
 * (or the head pointer itself, if the list is empty) is simply pointed at
 * `src', so both lists share the same tail afterwards.
 *
 * This is the only kind of merge that is needed, because a state only ever
 * inherits the outputs of the state it falls back to.
 */
static void
merge_output(struct kmp_output **target, struct kmp_output *src)
{
	struct kmp_output **tail;

	for (tail = target; *tail; tail = &(*tail)->next)
		;
	*tail = src;
}

/*
 * Allocate a single-element output list carrying `id' from the automaton's
 * memory pool.
 */
static struct kmp_output *
new_output(struct kmp *kmp, uns id)
{
	struct kmp_output *node = mp_alloc(kmp->mp, sizeof(struct kmp_output));

	node->id = id;
	node->next = NULL;
	return node;
}

/*
 * Map one decoded input character to the automaton's alphabet.
 *
 *  - 0 stays 0 (end of string).
 *  - Values with the top bit set (>= 0x80000000) become CONTROL_CHAR.
 *  - Otherwise the character goes through Utolower() and/or Uunaccent()
 *    when MF_TOLOWER / MF_UNACCENT are set in modify_flags, in that order.
 *  - With MF_ONLYALPHA set, anything for which Ualpha() is false after the
 *    above steps becomes CONTROL_CHAR.
 */
static inline kmp_char_t
translate_char(uns c, uns modify_flags)
{
	if (c == 0)
		return 0;
	if (c >= 0x80000000)
		return CONTROL_CHAR;

	if (modify_flags & MF_TOLOWER)
		c = Utolower(c);
	if (modify_flags & MF_UNACCENT)
		c = Uunaccent(c);

	if ((modify_flags & MF_ONLYALPHA) && !Ualpha(c))
		return CONTROL_CHAR;
	return c;
}

/*
 * Read the next translated character from `*str', advancing the pointer.
 *
 * On entry `*c' holds the previously delivered character.  Runs of
 * CONTROL_CHAR are collapsed: while the previous character is
 * CONTROL_CHAR, further characters that translate to CONTROL_CHAR are
 * skipped.  On return `*c' holds the new character (0 at end of string).
 */
static inline void
get_char(const byte **str, kmp_char_t *c, uns modify_flags)
{
	for (;;)
	{
		uns raw;
		kmp_char_t translated;

		GET_TAGGED_CHAR((*str), raw);
		translated = translate_char(raw, modify_flags);
		if (translated == CONTROL_CHAR && *c == CONTROL_CHAR)
			continue;	/* repeated separator, swallow it */
		*c = translated;
		return;
	}
}

/*
 * Add the pattern `str' to the automaton and attach output `id' to the
 * state reached at its end.
 *
 * The pattern is read through get_char(), i.e. with the automaton's
 * modify_flags applied and runs of CONTROL_CHAR collapsed.  A pattern that
 * translates to the empty string is ignored.  Entering the same pattern
 * several times appends one more output to the same state each time.
 *
 * Every newly created state consumes one slot of the per-state arrays,
 * which were sized by the words_len given to kmp_new(); running out of
 * slots trips an ASSERT before anything is written out of bounds.
 */
void
kmp_enter_string(struct kmp *kmp, const byte *str, uns id)
{
	struct kmp_transition tr, **prev;
	struct kmp_output *new_out;
	const byte *orig_str = str;
	/* Any value other than CONTROL_CHAR, so that a leading CONTROL_CHAR
	 * of the pattern is not swallowed by get_char().  */
	kmp_char_t c = 'a';

	tr.next = NULL;
	tr.from = 0;
	TRACE(20, "kmp.c: Entering string %s", str);
	get_char(&str, &c, kmp->modify_flags);
	if (!c)
		return;
	/* Walk along the transitions that already exist for this prefix.  */
	while (c)
	{
		tr.c = c;
		prev = transition_search(&kmp->g, &tr);
		if (!*prev)
			break;
		tr.from = (*prev)->to;
		get_char(&str, &c, kmp->modify_flags);
	}
	/* Create new states for the rest of the pattern.  Here `prev' always
	 * points to the empty hash link where the transition (tr.from, c)
	 * belongs.  */
	while (c)
	{
		/* sons[], f[] and out[] only have words_len entries; the
		 * check in kmp_build() would come too late to prevent the
		 * out-of-bounds write below.  */
		ASSERT(kmp->g.count < kmp->words_len);
		*prev = mp_alloc_zero(kmp->mp, sizeof(struct kmp_transition));
		tr.to = kmp->g.count++;
		**prev = tr;
		add_tail(kmp->g.sons + tr.from, &(*prev)->n);
		init_list(kmp->g.sons + tr.to);
		get_char(&str, &c, kmp->modify_flags);
		tr.from = tr.to;
		tr.c = c;
		prev = transition_search(&kmp->g, &tr);
		ASSERT(!*prev);
	}
	/* tr.from is now the state corresponding to the whole pattern.  */
	if (kmp->out[tr.from])
		TRACE(5, "kmp.c: string %s is inserted more than once", orig_str);
	new_out = new_output(kmp, id);
	merge_output(kmp->out + tr.from, new_out);
}

/*
 * Compute the fallback function f[] and complete the output lists out[]
 * for all states.
 *
 * States are processed in breadth-first order starting from the sons of
 * state 0, using a queue of state numbers.  For a state `child' reached
 * from `parent' by character c, f[child] is the target of the transition
 * on c from the first state in the chain f[parent], f[f[parent]], ... that
 * has one, or 0 if the chain reaches state 0 without finding any.  The
 * output list of f[child] is then appended to the output list of `child'.
 *
 * NOTE(review): the queue is taken from the stack via alloca() and holds
 * words_len entries; this may be a concern for very large automata.
 */
static void
construct_f_out(struct kmp *kmp)
{
	kmp_state_t *queue = alloca(kmp->words_len * sizeof(kmp_state_t));
	int head = 0, tail = 0;
	struct kmp_transition *edge;

	/* State 0 and all of its sons fall back to state 0.  */
	kmp->f[0] = 0;
	WALK_LIST(edge, kmp->g.sons[0])
	{
		ASSERT(edge->from == 0);
		kmp->f[edge->to] = 0;
		queue[tail++] = edge->to;
	}

	for (; head != tail; head++)
	{
		kmp_state_t parent = queue[head];

		WALK_LIST(edge, kmp->g.sons[parent])
		{
			struct kmp_transition probe, **hit;
			kmp_state_t child = edge->to;

			ASSERT(edge->from == parent);
			queue[tail++] = child;

			/* Follow the parent's fallback chain until some state
			 * has a transition on this character, or until
			 * state 0 has been tried.  */
			probe.c = edge->c;
			for (probe.from = kmp->f[parent]; ; probe.from = kmp->f[probe.from])
			{
				hit = transition_search(&kmp->g, &probe);
				if (*hit || !probe.from)
					break;
			}
			kmp->f[child] = *hit ? (*hit)->to : 0;

			/* Whatever is reported in the fallback state is
			 * reported here as well.  */
			merge_output(kmp->out + child, kmp->out[ kmp->f[child] ]);
		}
	}
}

/*
 * Finish the automaton after all patterns have been entered: check that
 * the number of states fits into the size given to kmp_new() and compute
 * the fallback function and the merged output lists.
 */
void
kmp_build(struct kmp *kmp)
{
	ASSERT(kmp->g.count <= kmp->words_len);
	construct_f_out(kmp);
	if (kmp->words_len <= 1)
		return;
	TRACE(0, "Built KMP with modify flags %d for total words len %d, it has %d nodes", kmp->modify_flags, kmp->words_len, kmp->g.count);
}

static inline void
add_result(struct list *nonzeroes, struct kmp_result *freq, struct kmp_output *out)
{
	for (; out; out = out->next)
		if (!freq[out->id].occur++)
			add_tail(nonzeroes, &freq[out->id].n);
}

void
kmp_search(struct kmp *kmp, const byte *str, struct list *nonzeroes, struct kmp_result *freq)
{
	kmp_state_t s = 0;
	kmp_char_t c = CONTROL_CHAR;
	struct kmp_transition tr, **prev;
	byte eof = 0;
	if (kmp->words_len <= 1)
		return;
	TRACE(20, "kmp.c: Searching string %s", str);
	while (1)
	{
		tr.from = s;
		tr.c = c;
		prev = transition_search(&kmp->g, &tr);
		while (tr.from && !*prev)
		{
			tr.from = kmp->f[ tr.from ];
			prev = transition_search(&kmp->g, &tr);
		}
		s = *prev ? (*prev)->to : 0;
		add_result(nonzeroes, freq, kmp->out[s]);
		if (eof)
			break;
		get_char(&str, &c, kmp->modify_flags);
		if (!c)
		{
			/* Insert CONTROL_CHAR at the beginning and at the end too.  */
			c = CONTROL_CHAR;
			eof = 1;
		}
	}
}
