/*
  predef.c
*/


#include <string.h>
#include <stdlib.h>
#include <ctype.h>
#include <assert.h>

#include "struct.h"
#include "status.h"
#include "predef.h"
#include "p97.h"
#include "clause.h"
#include "garb.h"
#include "unify.h"

/* Outcome of eval() on an arithmetic term:
     VALUE    - the term reduced to an integer, stored through eval's val,
     VARIABLE - the term itself (at top level) is an unbound variable,
     ERROR    - the term could not be evaluated (NULL term, non-integer
                atom, unknown operator, or an operand that is not a value,
                including a nested unbound variable). */
typedef enum { ERROR = 1, VARIABLE = 0, VALUE = -1 } expr_stat_t;


/* Internal helpers: expression evaluation and predicate-table lookup. */
static void eval(OBJ *p, int *val, expr_stat_t *stat);
static PREDEF *checkpredef(char *s, PREDEF *p);
/* Built-in predicates dispatched through predefv.  Each receives the goal
   term and returns 1 to succeed or 0 to fail (compare fail() and nl()). */
static int cut(OBJ *st);
static int fail(OBJ *st);
static int integer(OBJ *st);
static int integer2(OBJ *st);
static int write(OBJ *st);
static int nl(OBJ *st);
static int nonvar(OBJ *st);
static int test(OBJ *st);
static int asserta(OBJ *st);
static int retract(OBJ *st);
static int is(OBJ *st);

/* Arithmetic comparison predicates; eval_arg evaluates both arguments. */
static int ne(OBJ *st);
static int eq(OBJ *st);
static int le(OBJ *st);
static int lt(OBJ *st);
static int ge(OBJ *st);
static int gt(OBJ *st);
static int eval_arg(OBJ *st, int *val_lft, int *val_rgt);

/* Binary arithmetic operator handlers, reached from eval() via otrv. */
static int plus(int ond1, int ond2);
static int multiply(int ond1, int ond2);
static int division(int ond1, int ond2);
static int mod(int ond1, int ond2);


/* Built-in predicate table, searched by ispredef() before ext_predefv.
   The entry with an empty name terminates the table. */
PREDEF predefv[] = {
  /* control */
  { "!",       cut      },
  { "fail",    fail     },
  /* type checks */
  { "nonvar",  nonvar   },
  { "integer", integer  },
  /* output */
  { "write",   write    },
  { "nl",      nl       },
  /* special-purpose digit-column check */
  { "test",    test     },
  /* clause database */
  { "assert",  asserta  },
  { "retract", retract  },
  /* arithmetic */
  { "is",      is       },
  { ">",       gt       },
  { ">=",      ge       },
  { "<",       lt       },
  { "<=",      le       },
  { "=:=",     eq       },
  { "=\\=",    ne       },
  /* sentinel */
  { "",        NULL     }
};

/* Optional second predicate table, consulted by ispredef() when non-NULL;
   presumably installed by another module (nothing here assigns it). */
PREDEF *ext_predefv = NULL;

/* Linear search of a sentinel-terminated PREDEF table for the entry named s.
   A NULL table is treated as empty.  Returns the matching entry or NULL. */
static PREDEF *checkpredef(char *s, PREDEF *p)
{
  if (p == NULL)
    return NULL;

  while (p->name[0] != '\0') {
    if (strcmp(p->name, s) == 0)
      return p;
    p++;
  }

  return NULL;
}

/* Looks up predicate name s, first among the built-ins (predefv) and then
   in the optional extension table (ext_predefv).  Returns NULL if absent. */
PREDEF *ispredef(char *s)
{
  PREDEF *found = checkpredef(s, predefv);

  return found != NULL ? found : checkpredef(s, ext_predefv);
}


/* Runs goal st as a built-in predicate if its (dereferenced) name is one.
   Returns the predicate's result, or NOPREDEF when no built-in matches. */
int predef(OBJ *st)
{
  PREDEF *entry;

  st = getobj(st);
  entry = ispredef(st->name);
  if (entry == NULL)
    return NOPREDEF;

  return entry->func(st);
}


/* !/0: sets the `cutted` flag on every status frame from the top of the
   status stack (sts_sp) down to, but not including, the frame whose data is
   the current status's father, and then on that father itself.  Always
   succeeds.  (The flag is presumably what stops further backtracking into
   those frames — confirm in the solver.) */
static int cut(OBJ *st)
{
  STATUS *father = ((STATUS *)sts_sp->data)->father;
  STACK *frame;

  for (frame = sts_sp; frame->data != father; frame = frame->prev)
    ((STATUS *)frame->data)->cutted = 1;

  father->cutted = 1;
  return 1;
}


/* fail/0: never succeeds. */
static int fail(OBJ *st)
{
  (void) st;   /* the goal term carries nothing fail needs */
  return 0;
}


/* nonvar/1: succeeds (returns 1) when the dereferenced first argument is not
   an unbound variable, fails (returns 0) otherwise. */
static int nonvar(OBJ *st)
{
  return !isvar_st(getobj(st->child));
}


/* test/5: test(C, Rout, A, B, Rin) — appears to check one column of a
   decimal addition.  A, B and Rin must be digit-string integers; with
   sum = A + B + Rin the goal succeeds when C equals sum % 10 and Rout equals
   sum / 10.  An unbound Rout is bound to sum / 10; a bound Rout is compared.
   C must already be bound (checked with assert). */
static int test(OBJ *st)
{
  int sum;
  OBJ tmp, *C, *Rout, *A, *B, *Rin;

  /* Check the arity BEFORE walking the argument chain: with fewer than five
     arguments the brother links run out and would be dereferenced as NULL. */
  if (st->arity != 5)
    return 0;

  C    = st->child;
  Rout = C->brother;
  A    = Rout->brother;
  B    = A->brother;
  Rin  = B->brother;

  if (!integer2(A) || !integer2(B) || !integer2(Rin))
    return 0;

  sum = atoi(getobj(A)->name) + atoi(getobj(B)->name) + atoi(getobj(Rin)->name);

  assert(!isvar_st(getobj(C)));

  if (atoi(getobj(C)->name) != sum % 10)
    return 0;

  if (!isvar_st(getobj(Rout)))
    return atoi(getobj(Rout)->name) == sum / 10;

  /* Build a constant term for the carry.  Zero every field first so no
     member of the struct is left indeterminate. */
  memset(&tmp, 0, sizeof tmp);
  tmp.arity = 0;
  tmp.brother = NULL;
  tmp.child = NULL;
  sprintf(tmp.name, "%d", sum / 10);

  /* NOTE(review): tmp lives in this stack frame; confirm link_var copies
     the term rather than retaining the pointer. */
  link_var(getobj(Rout), &tmp);
  return 1;
}


/* integer/1: succeeds when the first argument is a digit-string constant
   (the check itself lives in integer2). */
static int integer(OBJ *st)
{
  OBJ *arg = st->child;

  return integer2(arg);
}

/* Returns 1 when term st dereferences to a constant whose name is a
   non-empty string of decimal digits, 0 otherwise.  Only non-negative
   literals are recognised (no sign is accepted). */
static int integer2(OBJ *st)
{
  const char *p;

  st = getobj(st);
  if (!isconst_st(st))
    return 0;

  p = st->name;
  if (*p == '\0')          /* an empty name is not a number */
    return 0;

  for (; *p; p++)
    /* isdigit() requires a value representable as unsigned char; a plain
       char holding a non-ASCII byte may be negative, which is UB. */
    if (!isdigit((unsigned char) *p))
      return 0;

  return 1;
}

/* write/1: prints the first argument with print_st; always succeeds. */
static int write(OBJ *st)
{
  OBJ *arg = st->child;

  print_st(arg);
  return 1;
}

/* nl/0: writes a newline to standard output; always succeeds. */
static int nl(OBJ *st)
{
  (void) st;
  putchar('\n');
  return 1;
}

/* assert/1: renders the clause argument to text, parses it back into a
   freshly allocated DB entry, and prepends that entry to the `src` clause
   list (presumably so it is tried before existing clauses).  Always
   succeeds. */
static int asserta(OBJ *st)
{
  DB *p = (DB *) mymalloc(sizeof *p);
  char buf[LINE_L + 1];

  /* NOTE(review): sprint_st is not given the size of buf; confirm it
     cannot write more than LINE_L characters for large terms. */
  sprint_st(buf, st->child);
  sread_cl(p->line, buf);
  p->next = src;
  src = p;

  /* predef() hands this value back to its caller, so the success result
     must be explicit; falling off the end of a non-void function is UB. */
  return 1;
}

/* retract/1: if find_cl locates a clause matching the argument, blanks its
   DB line and destroys the head/body terms it produced.  Succeeds whether or
   not anything matched. */
static int retract(OBJ *st)
{
  CLAUSE match;
  DB *entry = find_cl(st->child, &match, NULL);

  if (entry != NULL) {
    entry->line[0] = '\0';
    destroy_st(match.head);
    destroy_st(match.body);
  }

  return 1;
}

/* is/2: X is Expr.
   The right-hand side must evaluate to an integer.
   - If the left side is bound, it is evaluated as well and the goal
     succeeds only when the two values are equal.
   - If the left side is an unbound variable, it is unified with a constant
     holding the right-hand value.
   Returns 0 on arity mismatch or on any evaluation error. */
static int is(OBJ *st)
{
  /* s_lft stays VARIABLE when the left side is unbound; previously it was
     read uninitialised on that path. */
  expr_stat_t s_lft = VARIABLE, s_rgt;
  int val_lft = 0, val_rgt = 0;
  OBJ num;

  if (st->arity != 2)
    return 0;

  if (!isvar_st(getobj(st->child))) {
    eval(st->child, &val_lft, &s_lft);
    if (s_lft == ERROR)
      return 0;
  }

  eval(st->child->brother, &val_rgt, &s_rgt);
  if (s_rgt != VALUE)
    return 0;

  if (s_lft == VALUE)
    return val_lft == val_rgt;

  /* Left side is unbound: build a constant term and unify.  Zero the whole
     struct so no member is indeterminate. */
  memset(&num, 0, sizeof num);
  sprintf(num.name, "%d", val_rgt);
  num.child = num.brother = NULL;
  num.arity = 0;

  /* NOTE(review): num lives in this stack frame; confirm unify copies the
     term rather than retaining the pointer. */
  return unify(st->child, &num);
}

/* Binary arithmetic operators understood by eval(): the operator name and
   the function applied to its two evaluated operands.  The entry with an
   empty name terminates the table. */
struct otr_t {
  char *otr;                  /* operator name as it appears in the term */
  int (*handle)(int, int);    /* combines the two operand values */
};

struct otr_t otrv[] = {
  { "+",   plus     },
  { "*",   multiply },
  { "//",  division },
  { "mod", mod      },
  { "",    NULL     }
};

/* Handler for `+`.  The sum is computed in unsigned arithmetic so overflow
   wraps around (two's complement) instead of being undefined behaviour;
   the conversion back to int is implementation-defined, not UB. */
static int plus(int ond1, int ond2)
{
  return (int) ((unsigned) ond1 + (unsigned) ond2);
}

/* Handler for `*`.  The product is computed in unsigned arithmetic so
   overflow wraps around (two's complement) instead of being undefined
   behaviour; the conversion back to int is implementation-defined. */
static int multiply(int ond1, int ond2)
{
  return (int) ((unsigned) ond1 * (unsigned) ond2);
}

/* Handler for `//`: integer division truncating toward zero (C semantics).
   Division by zero prints an error and terminates the interpreter. */
static int division(int ond1, int ond2)
{
  if (!ond2) {
    printf("Division by zero error!\n");
    exit(1);
  }

  /* INT_MIN / -1 overflows int: undefined behaviour, and a SIGFPE trap on
     x86.  Since x / -1 == -x, negate via unsigned arithmetic, which wraps
     INT_MIN to itself instead of invoking UB. */
  if (ond2 == -1)
    return (int) (0u - (unsigned) ond1);

  return ond1 / ond2;
}

/* Handler for `mod`: C remainder (its sign follows the dividend).  A zero
   divisor prints an error and terminates the interpreter. */
static int mod(int ond1, int ond2)
{
  if (!ond2) {
    printf("Division by zero error!\n");
    exit(1);
  }

  /* x % -1 is always 0, but INT_MIN % -1 is undefined behaviour in C
     because the matching quotient overflows; answer it directly. */
  if (ond2 == -1)
    return 0;

  return ond1 % ond2;
}

/* Evaluates arithmetic term p.  On return *stat is:
     VALUE    - p reduced to an integer, stored in *val;
     VARIABLE - p itself is an unbound variable (only at this top level);
     ERROR    - anything else.  This covers a NULL term, a non-integer atom,
                an operator not in otrv or not applied to exactly two
                arguments, and an operand that is not a value (including a
                nested unbound variable).
   *val is written only when *stat is VALUE. */
static void eval(OBJ *p, int *val, expr_stat_t *stat)
{
  struct otr_t *o;
  int ond1, ond2;

  if (!p) {
    *stat = ERROR;
    return;
  }

  p = getobj(p);

  if (isvar_st(p)) {
    *stat = VARIABLE;
    return;
  }

  if (integer2(p)) {
    *stat = VALUE;
    *val = atoi(p->name);
    return;
  }

  /* Resolve the operator before touching the operands.  This rejects
     terms like +(1,2,3), which used to ignore the extra arguments.  It
     also avoids evaluating subterms, such as a `//` by zero that exits
     the process, under an operator that would be rejected anyway. */
  for (o = otrv; o->otr[0] && strcmp(o->otr, p->name); o++)
    ;

  if (!o->otr[0] || p->arity != 2) {
    *stat = ERROR;
    return;
  }

  eval(p->child, &ond1, stat);
  if (*stat != VALUE) {
    *stat = ERROR;
    return;
  }

  eval(p->child->brother, &ond2, stat);
  if (*stat != VALUE) {
    *stat = ERROR;
    return;
  }

  *stat = VALUE;
  *val = o->handle(ond1, ond2);
}

/* >/2: succeeds when both arguments evaluate and left > right. */
static int gt(OBJ *st)
{
  int lhs, rhs;

  if (!eval_arg(st, &lhs, &rhs))
    return 0;
  return lhs > rhs;
}

/* >=/2: succeeds when both arguments evaluate and left >= right. */
static int ge(OBJ *st)
{
  int lhs, rhs;

  if (!eval_arg(st, &lhs, &rhs))
    return 0;
  return lhs >= rhs;
}

/* </2: succeeds when both arguments evaluate and left < right. */
static int lt(OBJ *st)
{
  int lhs, rhs;

  if (!eval_arg(st, &lhs, &rhs))
    return 0;
  return lhs < rhs;
}

/* <=/2: succeeds when both arguments evaluate and left <= right. */
static int le(OBJ *st)
{
  int lhs, rhs;

  if (!eval_arg(st, &lhs, &rhs))
    return 0;
  return lhs <= rhs;
}

/* =:=/2: succeeds when both arguments evaluate to equal integers. */
static int eq(OBJ *st)
{
  int lhs, rhs;

  if (!eval_arg(st, &lhs, &rhs))
    return 0;
  return lhs == rhs;
}

/* =\=/2: succeeds when both arguments evaluate to different integers. */
static int ne(OBJ *st)
{
  int lhs, rhs;

  if (!eval_arg(st, &lhs, &rhs))
    return 0;
  return lhs != rhs;
}

/* Evaluates both arguments of a binary comparison goal st.
   Returns 1 when both reduce to integer values, which are then stored in
   *val_lft and *val_rgt.  Returns 0 otherwise (unbound variable, unknown
   operator, missing argument, ...).  The right argument is not evaluated
   when the left one fails. */
static int eval_arg(OBJ *st, int *val_lft, int *val_rgt)
{
  expr_stat_t stat;

  eval(st->child, val_lft, &stat);
  if (stat != VALUE)
    return 0;

  eval(st->child->brother, val_rgt, &stat);
  return stat == VALUE;
}