有鑑於這題都沒人發題解,我來發一下吧。
基本上這題毫無反應,就是個苦工題:沒有什麼演算法上的難點,難在實作量。
大概只要把 statement 的 type 們都定義好,照著定義寫就可以。
由於計算量並不大,我們不需要認真管理記憶體(new 出來的東西都不回收),反正程式跑完就全部 release 了。
剩下的就是要用力地 parse input,不過因為輸入是 S-expression,整體都還算好寫。
少數需要注意的大概只有 display 之類的內建函式:它們不能寫死成特殊的 statement,而要當成一般的 lambda 放進環境裡,
因為測資中有將 display 當作 lambda 傳遞的情況,這點在題目敘述中似乎沒有寫清楚。
總之就算用力寫,大概也要寫上一個半小時... 比賽中不太容易過。
#include <cctype>
#include <cstdio>
#include <cstdlib>
#include <map>
#include <string>
#include <typeinfo>
#include <vector>
using namespace std;
struct lambda_t;

// Runtime value. The payload is untyped: integers are boxed as a heap-allocated int,
// closures are stored as a lambda_t*. Nothing records which kind a given object holds,
// so the caller must pick the matching accessor.
struct object_t {
    void *data;

    // Boxes an integer on the heap.
    object_t(int number) : data(static_cast<void*>(new int(number))) {}

    // Wraps an existing pointer (for example a lambda_t*) without copying what it points to.
    object_t(void *payload) : data(payload) {}

    // Reads the payload as a boxed int.
    int as_int() {
        const int *boxed = static_cast<const int*>(data);
        return *boxed;
    }

    // Reinterprets the payload as a closure pointer.
    lambda_t* as_lambda() {
        lambda_t *fn = static_cast<lambda_t*>(data);
        return fn;
    }
};
// Variable bindings: identifier -> value object.
typedef map<string, object_t*> env_t;

// Base class of every syntax-tree node. eval() computes the node's value in the given
// environment; it takes the environment by non-const reference because nodes such as
// define write new bindings into it.
struct stmt_t {
    // Polymorphic base: a virtual destructor makes deleting a node through stmt_t* well-defined.
    virtual ~stmt_t() {}
    virtual object_t *eval(env_t &renv) = 0;
};
// A user-defined closure: parameter names, body, and a snapshot of the environment that was
// current when the lambda expression was evaluated.
struct lambda_t {
    env_t env;
    stmt_t *body;
    vector<string> args;
    lambda_t(stmt_t *_body, const vector<string> &_args, const env_t &_env) : env(_env), body(_body), args(_args) {}

    // Polymorphic base (native_lambda_t overrides call), so the destructor is virtual.
    virtual ~lambda_t() {}

    // Invokes the closure with already-evaluated argument values.
    // The callee environment is built in three layers, later layers winning:
    //   1. a copy of the caller's environment (so names bound after the closure was created,
    //      e.g. the closure's own name when it was created by define, remain visible);
    //   2. the captured environment;
    //   3. the parameter bindings.
    // Returns the value of the body evaluated in that environment.
    virtual object_t *call(env_t &renv, vector<object_t*> &vals) {
        env_t nenv = renv;
        for (env_t::const_iterator it = env.begin(); it != env.end(); ++it)
            nenv[it->first] = it->second;
        // Bind at most one value per parameter: surplus arguments are ignored instead of
        // indexing past the end of args; missing ones leave the parameter unbound.
        size_t nbind = vals.size() < args.size() ? vals.size() : args.size();
        for (size_t i = 0; i < nbind; ++i)
            nenv[args[i]] = vals[i];
        return body->eval(nenv);
    }
};
// A leaf of the syntax tree: either an integer literal or a variable reference.
struct term_stmt_t : public stmt_t {
    string ident;
    term_stmt_t(string _ident) : ident(_ident) {}

    // True when the token spells an integer literal: a decimal digit, optionally preceded by
    // a '-' sign. A lone "-" is still the subtraction identifier.
    static bool is_number(const string &tok) {
        size_t pos = (tok.size() > 1 && tok[0] == '-') ? 1 : 0;
        // isdigit() is undefined for negative char values, so classify via unsigned char.
        return pos < tok.size() && isdigit(static_cast<unsigned char>(tok[pos]));
    }

    // Literals evaluate to a fresh integer object. Identifiers evaluate to their binding,
    // or NULL when unbound.
    object_t *eval(env_t &renv) {
        if (is_number(ident))
            return new object_t(static_cast<int>(strtol(ident.c_str(), NULL, 10)));
        // Use find() rather than operator[]: operator[] would insert a NULL binding for an
        // unbound name, and a closure created later would capture that NULL and let it
        // shadow a real binding from the caller's environment.
        env_t::iterator it = renv.find(ident);
        if (it == renv.end())
            return NULL;
        return it->second;
    }
};
// (if cond then else): evaluates cond and then exactly one of the two branches.
struct if_stmt_t : public stmt_t {
    stmt_t *cond, *stmt1, *stmt2;
    if_stmt_t(stmt_t *test_expr, stmt_t *then_expr, stmt_t *else_expr)
        : cond(test_expr), stmt1(then_expr), stmt2(else_expr) {}

    object_t *eval(env_t &renv) {
        object_t *outcome = cond->eval(renv);
        // Any non-zero integer counts as true.
        if (outcome->as_int() != 0)
            return stmt1->eval(renv);
        return stmt2->eval(renv);
    }
};
// (define name expr): evaluates expr in the current environment, binds the result to name
// in that same environment, and yields the bound value.
struct define_stmt_t : public stmt_t {
    string name;
    stmt_t *stmt;
    define_stmt_t(string var, stmt_t *value_expr) : name(var), stmt(value_expr) {}

    object_t *eval(env_t &renv) {
        object_t *value = stmt->eval(renv);
        renv[name] = value;
        return value;
    }
};
// (lambda (params...) body): evaluating it creates a closure that snapshots the current
// environment together with the parameter names and body.
struct lambda_stmt_t : public stmt_t {
    stmt_t *body;
    vector<string> args;
    // The parameter list is only copied, so take it by const reference; this also accepts
    // temporaries, and existing callers passing a non-const lvalue keep working.
    lambda_stmt_t(stmt_t *_body, const vector<string> &_args) : body(_body), args(_args) {}
    object_t *eval(env_t &renv) {
        return new object_t(new lambda_t(body, args, renv));
    }
};
// (f a1 ... an): a function application.
struct call_stmt_t : public stmt_t {
    stmt_t *func;
    vector<stmt_t*> args;
    // The argument list is only copied, so take it by const reference (accepts temporaries;
    // existing callers passing a non-const lvalue keep working).
    call_stmt_t(stmt_t *_func, const vector<stmt_t*> &_args) : func(_func), args(_args) {}

    // Evaluates the callee first, then every argument left to right in the caller's
    // environment, then invokes the callee with the resulting values.
    object_t *eval(env_t &renv) {
        lambda_t *lmb = func->eval(renv)->as_lambda();
        vector<object_t*> vals;
        vals.reserve(args.size());
        for (size_t i = 0; i < args.size(); ++i)
            vals.push_back(args[i]->eval(renv));
        return lmb->call(renv, vals);
    }
};
typedef object_t *(*handler_t)(vector<object_t*>&);
struct native_lambda_t : public lambda_t {
handler_t hdr;
native_lambda_t(handler_t _hdr) : lambda_t(NULL, vector<string>(), env_t()), hdr(_hdr) {}
virtual object_t *call(env_t &renv, vector<object_t*> &vals) { return hdr(vals); }
};
// Built-in "+": returns a new integer object holding the sum of the first two arguments.
object_t *add_hdr(vector<object_t*> &vals) {
    const int lhs = vals[0]->as_int();
    const int rhs = vals[1]->as_int();
    return new object_t(lhs + rhs);
}
// Built-in "-": returns a new integer object holding first argument minus second argument.
object_t *sub_hdr(vector<object_t*> &vals) {
    object_t *minuend = vals[0];
    object_t *subtrahend = vals[1];
    return new object_t(minuend->as_int() - subtrahend->as_int());
}
// Built-in "<": returns integer 1 when the first argument is less than the second, else 0.
object_t *lt_hdr(vector<object_t*> &vals) {
    const int a = vals[0]->as_int();
    const int b = vals[1]->as_int();
    int truth = 0;
    if (a < b)
        truth = 1;
    return new object_t(truth);
}
// Built-in "begin": the caller has already evaluated every argument in order, so this just
// yields the last value.
object_t *begin_hdr(vector<object_t*> &vals) {
    // "(begin)" with no arguments: vals.back() on an empty vector is undefined behavior.
    // Return integer 0, the same placeholder display uses when it has no meaningful result.
    if (vals.empty())
        return new object_t(0);
    return vals.back();
}
// Built-in "display": prints the first argument as an integer followed by a newline.
object_t *display_hdr(vector<object_t*> &vals) {
    const int shown = vals[0]->as_int();
    printf("%d\n", shown);
    // There is no meaningful result; hand back integer 0 so the caller always gets an object.
    return new object_t(0);
}
// Returns the next input character as a value in 0..255, refilling a line buffer from stdin
// as needed. Lines that begin with ';' are echoed to stdout unchanged and otherwise skipped.
// Throws the int -1 when stdin is exhausted.
int next_char() {
    static char buf[10010], *ptr = buf;
    if (!*ptr) {
        while (true) {
            if (fgets(buf, sizeof(buf), stdin) == NULL) throw -1;
            if (buf[0] == ';') fputs(buf, stdout);
            else break;
        }
        ptr = buf;
    }
    // Go through unsigned char: a plain char may be signed, and a byte with the high bit set
    // would otherwise become a negative int. That is undefined behavior for the <cctype>
    // classifiers, and the value 0xFF would be indistinguishable from EOF (-1).
    return static_cast<unsigned char>(*ptr++);
}
// Returns the next token. A token is "(", ")", or a maximal run of characters containing no
// whitespace and no parentheses (identifiers and numeric literals). Leading whitespace,
// including newlines, is skipped. The exception next_char() throws at end of input propagates.
string next_token() {
    // One character of look-ahead: the delimiter that ended the previous token is kept here
    // and consumed first on the next call. A separate flag is used instead of an EOF sentinel,
    // so no input byte value can be mistaken for "nothing pending".
    static bool has_pending = false;
    static int pending = 0;

    int ch;
    if (has_pending) {
        ch = pending;
        has_pending = false;
    } else {
        ch = next_char();
    }
    // Classify via unsigned char: isspace() is undefined for negative values other than EOF.
    while (isspace(static_cast<unsigned char>(ch))) ch = next_char();
    if (ch == '(') return "(";
    if (ch == ')') return ")";

    string token;
    do {
        token += static_cast<char>(ch);
        ch = next_char();
    } while (!isspace(static_cast<unsigned char>(ch)) && ch != ')' && ch != '(');
    pending = ch;
    has_pending = true;
    return token;
}
// Parses one expression from the token stream and returns its syntax tree.
// Returns NULL when the next token is ")"; list parsers use that to detect the end of a list.
// The head identifier selects the special forms (define name expr),
// (lambda (params...) body) and (if cond then else). Any other list is a call.
stmt_t *next_stmt() {
    string token = next_token();
    if (token == ")") return NULL;
    if (token == "(") {
        stmt_t *fstmt = next_stmt(), *ret;
        // Only a bare identifier in head position can name a special form. The pointer form
        // of dynamic_cast returns NULL for other heads instead of throwing std::bad_cast.
        // It also avoids dereferencing a NULL head (an empty list "()").
        string fname = "";
        if (term_stmt_t *head = dynamic_cast<term_stmt_t*>(fstmt))
            fname = head->ident;
        if (fname == "define") {
            string name = next_token();
            ret = new define_stmt_t(name, next_stmt());
            next_token(); // ")"
        } else if (fname == "lambda") {
            next_token(); // "("
            vector<string> args;
            for (string str = next_token(); str != ")"; str = next_token())
                args.push_back(str);
            ret = new lambda_stmt_t(next_stmt(), args);
            next_token(); // ")"
        } else {
            vector<stmt_t*> stmts;
            for (stmt_t *stmt = next_stmt(); stmt; stmt = next_stmt())
                stmts.push_back(stmt);
            if (fname == "if") {
                // A missing branch (e.g. "(if c a)") would index past the end of stmts.
                // Default absent parts to the literal 0.
                while (stmts.size() < 3)
                    stmts.push_back(new term_stmt_t("0"));
                ret = new if_stmt_t(stmts[0], stmts[1], stmts[2]);
            } else {
                ret = new call_stmt_t(fstmt, stmts);
            }
        }
        return ret;
    }
    return new term_stmt_t(token);
}
int main() {
env_t env;
env["+"] = new object_t(new native_lambda_t(add_hdr));
env["-"] = new object_t(new native_lambda_t(sub_hdr));
env["<"] = new object_t(new native_lambda_t(lt_hdr));
env["begin"] = new object_t(new native_lambda_t(begin_hdr));
env["display"] = new object_t(new native_lambda_t(display_hdr));
try {
while (true) next_stmt()->eval(env);
} catch (int e) {}
}