/**
* Copyright 2009 Google Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS-IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package net.appjet.common.sars;
import scala.collection.mutable.HashSet;
import java.net.{Socket, ServerSocket, InetSocketAddress, SocketException};
import java.io.{DataInputStream, DataOutputStream, IOException}
import java.util.concurrent.atomic.AtomicBoolean;
/**
 * Callback used by the Sars server to answer incoming messages.
 *
 * There is one overload per payload kind (UTF string or raw bytes). Each one
 * returns the reply to send back, or `None`. Both default to `None`, so an
 * implementation only overrides the kinds it cares about.
 */
trait SarsMessageHandler {
  /** Answers a string message; the default has no answer. */
  def handle(s: String): Option[String] = {
    None
  }

  /** Answers a byte-array message; the default has no answer. */
  def handle(b: Array[Byte]): Option[Array[Byte]] = {
    None
  }
}
/** Common base class of every error raised by the Sars protocol code; `m` is the message. */
class SarsException(m: String)
  extends RuntimeException(m)

/** Raised when the peer closes the channel while a reply is still expected. */
class ChannelClosedUnexpectedlyException
  extends SarsException("Sars channel closed unexpectedly.")

/** Raised when the server does not accept the client's authentication response. */
class BadAuthKeyException
  extends SarsException("Sars authKey not accepted.")

/** Raised when a message is sent before authentication has completed. */
class NotAuthenticatedException
  extends SarsException("Sars must authenticate before sending message.")

/** Raised for a message whose type is not recognised; `t` describes the offending type. */
class UnknownTypeException(t: String)
  extends SarsException("Sars type unknown: " + t)
/**
 * Wire framing shared by the Sars client and server.
 *
 * Each frame is an Int type tag followed by its payload:
 *  - byteArray (1): an Int length, then that many raw bytes;
 *  - utf8String (2): a string in `DataOutputStream.writeUTF` format, which
 *    limits a single string frame to 65535 encoded bytes.
 *
 * Implementations supply the underlying streams.
 */
private[sars] trait SarsMessageReaderWriter {
  /** Type tag of a raw byte-array frame. */
  def byteArray = 1
  /** Type tag of a UTF string frame. */
  def utf8String = 2
  def inputStream: DataInputStream
  def outputStream: DataOutputStream

  /**
   * Reads one frame.
   *
   * @return `Some(Array[Byte])` or `Some(String)` for a complete frame. Returns `None` when
   *         the stream ends before a type tag arrives, when reading the payload fails with
   *         an IOException, or when a byte frame declares a negative length.
   * @throws UnknownTypeException if the type tag is neither byteArray nor utf8String.
   * @throws IOException if reading the type tag fails for a reason other than end of
   *         stream (for example a socket read timeout).
   */
  def readMessage: Option[Any] = {
    val messageType: Option[Int] =
      try {
        Some(inputStream.readInt())
      } catch {
        // The peer closed the stream between frames. Report it the same way as a
        // truncated frame rather than leaking an EOFException to callers that test for None.
        case eof: java.io.EOFException => None
      }
    messageType match {
      case None => None
      case Some(t) if t == byteArray =>
        readPayload {
          val len = inputStream.readInt()
          if (len < 0) {
            // Only a corrupt or hostile peer sends a negative length; the stream cannot be
            // resynchronised, so report the frame as unreadable instead of throwing
            // NegativeArraySizeException.
            None
          } else {
            val bytes = new Array[Byte](len)
            inputStream.readFully(bytes)
            Some(bytes)
          }
        }
      case Some(t) if t == utf8String =>
        readPayload(Some(inputStream.readUTF()))
      case Some(t) =>
        throw new UnknownTypeException("type " + t)
    }
  }

  /** Evaluates `read`, mapping an IOException (truncated or failed payload) to None. */
  private def readPayload[A](read: => Option[A]): Option[A] =
    try {
      read
    } catch {
      case ioe: IOException => None
    }

  /** Reads one frame and returns it if it is a string; a frame of another kind is consumed and None returned. */
  def readString: Option[String] = readMessage match {
    case Some(s: String) => Some(s)
    case _ => None
  }

  /** Reads one frame and returns it if it is a byte array; a frame of another kind is consumed and None returned. */
  def readBytes: Option[Array[Byte]] = readMessage match {
    case Some(b: Array[Byte]) => Some(b)
    case _ => None
  }

  /** Writes `bytes` as a byte-array frame and flushes it to the peer. */
  def writeMessage(bytes: Array[Byte]): Unit = {
    outputStream.writeInt(byteArray)
    outputStream.writeInt(bytes.length)
    outputStream.write(bytes)
    outputStream.flush()
  }

  /**
   * Writes `string` as a UTF string frame and flushes it to the peer.
   *
   * @throws java.io.UTFDataFormatException if the encoded string exceeds 65535 bytes.
   */
  def writeMessage(string: String): Unit = {
    outputStream.writeInt(utf8String)
    outputStream.writeUTF(string)
    outputStream.flush()
  }
}
/**
 * Client end of a Sars connection to `host:port`.
 *
 * After connecting, it authenticates with `authKey`. Each `message` call then sends
 * one request and waits for one reply. The connection is (re)opened lazily by
 * `message` whenever there is no usable socket. This class does no synchronization
 * of its own.
 */
class SarsClient(authKey: String, host: String, port: Int) {
  /** Speaks the Sars protocol over the already-connected socket `s`. */
  class SarsClientHandler(s: Socket) {
    val readerWriter = new SarsMessageReaderWriter {
      val inputStream = new DataInputStream(s.getInputStream())
      val outputStream = new DataOutputStream(s.getOutputStream())
    }
    var authenticated = false

    /**
     * Challenge/response handshake. Reads the server's challenge, answers with
     * SimpleSHA1(authKey + challenge), and expects the reply "ok".
     *
     * @throws ChannelClosedUnexpectedlyException if no challenge arrives.
     * @throws BadAuthKeyException if the reply is not "ok", including when the server
     *         hangs up without replying.
     */
    def auth(): Unit = {
      val challenge = readerWriter.readString match {
        case Some(c) => c
        case None => throw new ChannelClosedUnexpectedlyException
      }
      readerWriter.writeMessage(SimpleSHA1(authKey + challenge))
      val res = readerWriter.readString
      if (res != Some("ok")) {
        // res is None when the server closed without replying; only print a reply that arrived.
        res.foreach(r => println(r))
        throw new BadAuthKeyException
      }
      authenticated = true
    }

    /**
     * Sends `q` with `writer`, then returns the reply obtained from `reader`.
     *
     * @throws NotAuthenticatedException if auth() has not succeeded.
     * @throws ChannelClosedUnexpectedlyException if no reply could be read.
     */
    def message[T](q: T, writer: T => Unit, reader: Unit => Option[T]): T = {
      if (!authenticated) {
        throw new NotAuthenticatedException
      }
      try {
        writer(q)
        reader(()) match {
          case Some(res) => res
          case None => throw new ChannelClosedUnexpectedlyException
        }
      } catch {
        case e: Throwable =>
          // Any failure leaves the stream in an unknown state. Drop the connection so that
          // the enclosing SarsClient reconnects on its next message() call.
          close()
          throw e
      }
    }

    def message(s: String): String =
      message[String](s, readerWriter.writeMessage, _ => readerWriter.readString)

    def message(b: Array[Byte]): Array[Byte] =
      message[Array[Byte]](b, readerWriter.writeMessage, _ => readerWriter.readBytes)

    /** Closes the underlying socket if it is still open. */
    def close(): Unit = {
      if (!s.isClosed) {
        s.close()
      }
    }
  }

  var socket: Socket = null
  var connectTimeout = 0
  var readTimeout = 0

  /** Sets the connect timeout in milliseconds used by later connect() calls (0 = no timeout). */
  def setConnectTimeout(timeout: Int): Unit = {
    connectTimeout = timeout
  }

  /** Sets the socket read timeout in milliseconds used by later connect() calls (0 = no timeout). */
  def setReadTimeout(timeout: Int): Unit = {
    readTimeout = timeout
  }

  var client: SarsClientHandler = null

  /**
   * Closes any open socket, then connects and authenticates a fresh one.
   * If connecting or authenticating fails, the new socket is closed before the error propagates.
   */
  def connect(): Unit = {
    if (socket != null && !socket.isClosed) {
      socket.close()
    }
    val fresh = new Socket()
    socket = fresh
    try {
      fresh.connect(new InetSocketAddress(host, port), connectTimeout)
      fresh.setSoTimeout(readTimeout)
      client = new SarsClientHandler(fresh)
      client.auth()
    } catch {
      case e: Throwable =>
        // Leaving a connected-but-unauthenticated socket would make message() skip
        // reconnecting and fail with NotAuthenticatedException forever.
        try {
          if (!fresh.isClosed) fresh.close()
        } catch {
          case ignored: IOException => ()
        }
        throw e
    }
  }

  /** Connects when connect() was never called or the previous socket is unusable. */
  private def ensureConnected(): Unit = {
    if (socket == null || !socket.isConnected || socket.isClosed) {
      connect()
    }
  }

  /** Sends a string message and returns the server's string reply, connecting first if needed. */
  def message(q: String): String = {
    ensureConnected()
    client.message(q)
  }

  /** Sends a byte-array message and returns the server's byte-array reply, connecting first if needed. */
  def message(b: Array[Byte]): Array[Byte] = {
    ensureConnected()
    client.message(b)
  }

  /** Closes the current connection, if any. */
  def close(): Unit = {
    if (client != null) {
      client.close()
    }
  }
}
/**
 * Sars server. It listens on `port`, bound to `host` when one is given and to all
 * interfaces otherwise. Each connection is authenticated with a challenge/response
 * based on `authKey`, then every message is answered through `handler`. Each
 * accepted client is served on its own thread.
 */
class SarsServer(authKey: String, handler: SarsMessageHandler, host: Option[String], port: Int) {
  /** Serves the single accepted connection `cs` until it closes or stop() is called. */
  class SarsServerHandler(cs: Socket) extends Runnable {
    @volatile var thread: Thread = null
    var running = new AtomicBoolean(false)

    def run(): Unit = {
      try {
        thread = Thread.currentThread()
        if (running.compareAndSet(false, true)) {
          serve()
        }
      } catch {
        // Peer disconnects, and stop() closing the socket, surface as IOExceptions: nothing to report.
        case e: IOException => ()
        // Anything else (handler bugs, unknown frame types) would otherwise vanish silently.
        case e: Exception => e.printStackTrace()
      } finally {
        server.synchronized {
          clients -= SarsServerHandler.this
        }
        cs.close()
      }
    }

    /** Runs the authentication handshake, then answers messages until the peer hangs up or stop() is called. */
    private def serve(): Unit = {
      val readerWriter = new SarsMessageReaderWriter {
        val inputStream = new DataInputStream(cs.getInputStream())
        val outputStream = new DataOutputStream(cs.getOutputStream())
      }
      val challenge = newChallenge()
      readerWriter.writeMessage(challenge)
      if (readerWriter.readString != Some(SimpleSHA1(authKey + challenge))) {
        readerWriter.writeMessage("invalid key")
      } else {
        readerWriter.writeMessage("ok")
        while (running.get()) {
          readerWriter.readMessage match {
            case None =>
              running.set(false)
            case Some(s: String) =>
              readerWriter.writeMessage(handler.handle(s).getOrElse(""))
            case Some(b: Array[Byte]) =>
              readerWriter.writeMessage(handler.handle(b).getOrElse(new Array[Byte](0)))
            case Some(x: AnyRef) =>
              throw new UnknownTypeException(x.getClass.getName)
          }
        }
      }
    }

    /**
     * Stops serving this client. A thread blocked in a socket read ignores
     * Thread.interrupt(), so the socket is closed as well to make that read fail.
     */
    def stop(): Unit = {
      running.set(false)
      val t = thread
      if (t != null) {
        t.interrupt()
      }
      try {
        cs.close()
      } catch {
        case e: IOException => ()
      }
    }
  }

  // new ServerSocket(port) would already be bound, so a second bind() to `host` could never
  // succeed. Create it unbound and bind exactly once, to the requested address.
  val ss = new ServerSocket()
  try {
    ss.bind(host match {
      case Some(h) => new InetSocketAddress(h, port)
      case None => new InetSocketAddress(port)
    })
  } catch {
    case e: Exception =>
      ss.close()
      throw e
  }

  var running = new AtomicBoolean(false)
  var hasRun = false
  var serverThread: Thread = null
  val clients = new HashSet[SarsServerHandler]
  var daemon = false
  val server = this

  /** Unpredictable source for auth challenges; Math.random is not meant for security use. */
  private val challengeRandom = new java.security.SecureRandom()

  /** A fresh 128-bit random challenge, hex-encoded; the client hashes it together with the auth key. */
  private def newChallenge(): String =
    new java.math.BigInteger(128, challengeRandom).toString(16)

  /**
   * Starts the accept thread (as a daemon thread when `daemon` is set).
   *
   * @throws RuntimeException if this server was started before.
   */
  def start(): Unit = {
    if (hasRun)
      throw new RuntimeException("Can't reuse server.")
    hasRun = true
    if (running.compareAndSet(false, true)) {
      serverThread = new Thread() {
        override def run(): Unit = acceptLoop()
      }
      if (daemon)
        serverThread.setDaemon(true)
      serverThread.start()
    } else {
      throw new RuntimeException("WTF, fool? Server's running already.")
    }
  }

  /** Accepts connections until stop() closes `ss` or accepting fails. */
  private def acceptLoop(): Unit = {
    var accepting = true
    while (accepting && running.get()) {
      try {
        val client = new SarsServerHandler(ss.accept())
        // Check `running` and register under the same lock stop() sweeps `clients` with,
        // so a connection accepted during shutdown is not left behind.
        val registered = server.synchronized {
          if (running.get()) {
            clients += client
            true
          } else {
            false
          }
        }
        if (registered) {
          new Thread(client).start()
        } else {
          client.stop()
        }
      } catch {
        case e: SocketException =>
          // stop() closes `ss`, which makes accept() throw; only unexpected failures are reported.
          if (running.get()) {
            println("socket exception.")
            e.printStackTrace()
            if (!ss.isClosed) {
              ss.close()
            }
          }
          accepting = false
        case e: IOException =>
          println("i/o error")
          e.printStackTrace()
          ss.close()
          accepting = false
      }
    }
  }

  /**
   * Closes the listening socket and stops every connected client.
   *
   * @throws RuntimeException if the server is not running.
   */
  def stop(): Unit = {
    if (running.compareAndSet(true, false)) {
      if (!ss.isClosed) {
        ss.close()
      }
      server.synchronized {
        for (client <- clients) {
          client.stop()
        }
      }
    } else {
      throw new RuntimeException("Not running.")
    }
  }

  /** Waits for the accept thread to finish; returns immediately if start() was never called. */
  def join(): Unit = {
    if (serverThread != null) {
      serverThread.join()
    }
  }
}
/**
 * Manual smoke test. It starts a server on port 9001, exchanges string and byte
 * messages with it, and shuts everything down. The server and client are released
 * even if an exchange fails, so the non-daemon server thread cannot keep the JVM alive.
 */
object test {
  def main(args: Array[String]): Unit = {
    val handler = new SarsMessageHandler {
      override def handle(s: String): Option[String] = {
        println("SERVER: " + s)
        if (s == "hello!") Some("hey there.") else None
      }
      override def handle(b: Array[Byte]): Option[Array[Byte]] = {
        val actually = new String(b, "UTF-8")
        println("SERVER: " + actually)
        if (actually == "hello!") Some("hey there.".getBytes("UTF-8")) else None
      }
    }
    val server = new SarsServer("nopassword", handler, None, 9001)
    server.start()
    try {
      val client = new SarsClient("nopassword", "localhost", 9001)
      try {
        client.connect()
        println("CLIENT: " + client.message("hello!"))
        println("CLIENT: " + client.message("goodbye!"))
        println("CLIENT: " + new String(client.message("hello!".getBytes("UTF-8")), "UTF-8"))
        println("CLIENT: " + new String(client.message("goodbye!".getBytes("UTF-8")), "UTF-8"))
      } finally {
        client.close()
      }
    } finally {
      server.stop()
    }
    server.join()
    println("done.")
  }
}