aboutsummaryrefslogtreecommitdiffstats
path: root/trunk/infrastructure/net.appjet.common.sars/sars.scala
diff options
context:
space:
mode:
Diffstat (limited to 'trunk/infrastructure/net.appjet.common.sars/sars.scala')
-rw-r--r--trunk/infrastructure/net.appjet.common.sars/sars.scala348
1 files changed, 348 insertions, 0 deletions
diff --git a/trunk/infrastructure/net.appjet.common.sars/sars.scala b/trunk/infrastructure/net.appjet.common.sars/sars.scala
new file mode 100644
index 0000000..f91b292
--- /dev/null
+++ b/trunk/infrastructure/net.appjet.common.sars/sars.scala
@@ -0,0 +1,348 @@
+/**
+ * Copyright 2009 Google Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS-IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package net.appjet.common.sars;
+
+import scala.collection.mutable.HashSet;
+import java.net.{Socket, ServerSocket, InetSocketAddress, SocketException};
+import java.io.{DataInputStream, DataOutputStream, IOException}
+import java.util.concurrent.atomic.AtomicBoolean;
+
+trait SarsMessageHandler {
+ def handle(s: String): Option[String] = None;
+ def handle(b: Array[byte]): Option[Array[byte]] = None;
+}
+
+class SarsException(m: String) extends RuntimeException(m);
+class ChannelClosedUnexpectedlyException extends SarsException("Sars channel closed unexpectedly.");
+class BadAuthKeyException extends SarsException("Sars authKey not accepted.");
+class NotAuthenticatedException extends SarsException("Sars must authenticate before sending message.");
+class UnknownTypeException(t: String) extends SarsException("Sars type unknown: "+t);
+
+private[sars] trait SarsMessageReaderWriter {
+ def byteArray = 1;
+ def utf8String = 2;
+
+ def inputStream: DataInputStream;
+ def outputStream: DataOutputStream;
+
+ def readMessage: Option[Any] = {
+ val messageType = inputStream.readInt;
+ if (messageType == byteArray) {
+ try {
+ val len = inputStream.readInt;
+ val bytes = new Array[byte](len);
+ inputStream.readFully(bytes);
+ Some(bytes);
+ } catch {
+ case ioe: IOException => None;
+ }
+ } else if (messageType == utf8String) {
+ try {
+ Some(inputStream.readUTF);
+ } catch {
+ case ioe: IOException => None;
+ }
+ } else {
+ throw new UnknownTypeException("type "+messageType);
+ }
+ }
+ def readString: Option[String] = {
+ val m = readMessage;
+ m.filter(_.isInstanceOf[String]).asInstanceOf[Option[String]];
+ }
+ def readBytes: Option[Array[byte]] = {
+ val m = readMessage;
+ m.filter(_.isInstanceOf[Array[byte]]).asInstanceOf[Option[Array[byte]]];
+ }
+ def writeMessage(bytes: Array[byte]) {
+ outputStream.writeInt(byteArray);
+ outputStream.writeInt(bytes.length);
+ outputStream.write(bytes);
+ }
+ def writeMessage(string: String) {
+ outputStream.writeInt(utf8String);
+ outputStream.writeUTF(string);
+ }
+}
+
+class SarsClient(authKey: String, host: String, port: Int) {
+
+ class SarsClientHandler(s: Socket) {
+ val readerWriter = new SarsMessageReaderWriter {
+ val inputStream = new DataInputStream(s.getInputStream());
+ val outputStream = new DataOutputStream(s.getOutputStream());
+ }
+ var authenticated = false;
+
+ def auth() {
+ val challenge = readerWriter.readString;
+ if (challenge.isEmpty) {
+ throw new ChannelClosedUnexpectedlyException;
+ }
+ readerWriter.writeMessage(SimpleSHA1(authKey+challenge.get));
+ val res = readerWriter.readString;
+ if (res.isEmpty || res.get != "ok") {
+ println(res.get);
+ throw new BadAuthKeyException;
+ }
+ authenticated = true;
+ }
+
+ def message[T](q: T, writer: T => Unit, reader: Unit => Option[T]): T = {
+ if (! authenticated) {
+ throw new NotAuthenticatedException;
+ }
+ try {
+ writer(q);
+ val res = reader();
+ if (res.isEmpty) {
+ throw new ChannelClosedUnexpectedlyException;
+ }
+ res.get;
+ } catch {
+ case e => {
+ if (! s.isClosed()) {
+ s.close();
+ }
+ throw e;
+ }
+ }
+ }
+
+ def message(s: String): String =
+ message[String](s, readerWriter.writeMessage, Unit => readerWriter.readString);
+
+ def message(b: Array[byte]): Array[byte] =
+ message[Array[byte]](b, readerWriter.writeMessage, Unit => readerWriter.readBytes);
+
+ def close() {
+ if (! s.isClosed) {
+ s.close();
+ }
+ }
+ }
+
+ var socket: Socket = null;
+ var connectTimeout = 0;
+ var readTimeout = 0;
+
+ def setConnectTimeout(timeout: Int) {
+ connectTimeout = timeout;
+ }
+ def setReadTimeout(timeout: Int) {
+ readTimeout = timeout;
+ }
+
+ var client: SarsClientHandler = null;
+ def connect() {
+ if (socket != null && ! socket.isClosed) {
+ socket.close();
+ }
+ socket = new Socket();
+ socket.connect(new InetSocketAddress(host, port), connectTimeout);
+ socket.setSoTimeout(readTimeout);
+ client = new SarsClientHandler(socket);
+ client.auth();
+ }
+
+ def message(q: String) = {
+ if (! socket.isConnected || socket.isClosed) {
+ connect();
+ }
+ client.message(q);
+ }
+
+ def message(b: Array[byte]) = {
+ if (! socket.isConnected || socket.isClosed) {
+ connect();
+ }
+ client.message(b);
+ }
+
+ def close() {
+ if (client != null) {
+ client.close();
+ }
+ }
+}
+
+class SarsServer(authKey: String, handler: SarsMessageHandler, host: Option[String], port: Int) {
+
+ // handles a single client.
+ class SarsServerHandler(cs: Socket) extends Runnable {
+ var thread: Thread = null;
+ var running = new AtomicBoolean(false);
+
+ def run() {
+ try {
+ thread = Thread.currentThread();
+ if (running.compareAndSet(false, true)) {
+ val readerWriter = new SarsMessageReaderWriter {
+ val inputStream = new DataInputStream(cs.getInputStream());
+ val outputStream = new DataOutputStream(cs.getOutputStream());
+ }
+ val challenge = Math.random*1e20;
+
+ readerWriter.writeMessage(String.valueOf(challenge));
+ val res = readerWriter.readString;
+ if (res.isEmpty || res.get != SimpleSHA1(authKey+challenge)) {
+ readerWriter.writeMessage("invalid key");
+ } else {
+ readerWriter.writeMessage("ok");
+ while (running.get()) {
+ val q = readerWriter.readMessage;
+ if (q.isEmpty) {
+ running.set(false);
+ } else {
+ q.get match {
+ case s: String => readerWriter.writeMessage(handler.handle(s).getOrElse(""));
+ case b: Array[byte] =>
+ readerWriter.writeMessage(handler.handle(b).getOrElse(new Array[byte](0)));
+ case x: AnyRef => throw new UnknownTypeException(x.getClass.getName);
+ }
+ }
+ }
+ }
+ }
+ } catch {
+ case e => { }
+ } finally {
+ cs.close();
+ }
+ }
+
+ def stop() {
+ if (running.compareAndSet(true, false)) {
+ thread.interrupt();
+ }
+ }
+ }
+
+ val ss = new ServerSocket(port);
+ if (host.isDefined) {
+ ss.bind(InetSocketAddress.createUnresolved(host.get, port));
+ }
+ var running = new AtomicBoolean(false);
+ var hasRun = false;
+ var serverThread: Thread = null;
+ val clients = new HashSet[SarsServerHandler];
+ var daemon = false;
+ val server = this;
+
+ def start() {
+ if (hasRun)
+ throw new RuntimeException("Can't reuse server.");
+ hasRun = true;
+ if (running.compareAndSet(false, true)) {
+ serverThread = new Thread() {
+ override def run() {
+ while(running.get()) {
+ val cs = try {
+ ss.accept();
+ } catch {
+ case e: SocketException => {
+ if (running.get()) {
+ println("socket exception.");
+ e.printStackTrace();
+ if (! ss.isClosed) {
+ ss.close();
+ }
+ return;
+ } else { // was closed by user.
+ return;
+ }
+ }
+ case e: IOException => {
+ println("i/o error");
+ e.printStackTrace();
+ ss.close();
+ return;
+ }
+ }
+ val client = new SarsServerHandler(cs);
+ server.synchronized {
+ clients += client;
+ }
+ (new Thread(client)).start();
+ }
+ }
+ }
+ if (daemon)
+ serverThread.setDaemon(true);
+ serverThread.start();
+ } else {
+ throw new RuntimeException("WTF, fool? Server's running already.");
+ }
+ }
+
+ def stop() {
+ if (running.compareAndSet(true, false)) {
+ if (! ss.isClosed) {
+ ss.close();
+ }
+ server.synchronized {
+ for (client <- clients) {
+ client.stop();
+ }
+ }
+ } else {
+ throw new RuntimeException("Not running.");
+ }
+ }
+
+ def join() {
+ serverThread.join();
+ }
+}
+
+object test {
+ def main(args: Array[String]) {
+ val handler = new SarsMessageHandler {
+ override def handle(s: String) = {
+ println("SERVER: "+s);
+ if (s == "hello!") {
+ Some("hey there.");
+ } else {
+ None;
+ }
+ }
+ override def handle(b: Array[byte]) = {
+ var actually = new String(b, "UTF-8");
+ println("SERVER: "+actually);
+ if (actually == "hello!") {
+ Some("hey there.".getBytes("UTF-8"));
+ } else {
+ None;
+ }
+ }
+ }
+
+ val server = new SarsServer("nopassword", handler, None, 9001);
+ server.start();
+
+ val client = new SarsClient("nopassword", "localhost", 9001);
+ client.connect();
+ println("CLIENT: "+client.message("hello!"));
+ println("CLIENT: "+client.message("goodbye!"));
+ println("CLIENT: "+new String(client.message("hello!".getBytes("UTF-8")), "UTF-8"));
+ println("CLIENT: "+new String(client.message("goodbye!".getBytes("UTF-8")), "UTF-8"));
+ client.close();
+ server.stop();
+ server.join();
+ println("done.");
+ }
+}