package com.heaven.terminal
import android.content.Context
import java.io.FileInputStream
import java.nio.ByteBuffer
import java.nio.ByteOrder
import java.nio.channels.FileChannel
import org.tensorflow.lite.Interpreter
/**
 * Handles loading and running an on-device TensorFlow Lite model for text generation.
 *
 * This class is designed as a singleton ('object') because model loading is expensive
 * and should only be done once.
 */
object ModelLoader {
    private const val MODEL_PATH = "ai_model.tflite"
    private const val VOCAB_PATH = "vocab.txt"
    private const val MAX_SEQ_LEN = 128 // Max length of input/output sequence for the model
    private var interpreter: Interpreter? = null
    private lateinit var vocabulary: Map<String, Int>
    private lateinit var reversedVocabulary: Map<Int, String>
    /**
     * Ensures the TFLite interpreter and vocabulary are loaded before use.
     * This function is synchronized to prevent race conditions from multiple threads.
     *
     * @param context The application context, needed to access assets.
     */
    @Synchronized
    private fun initialize(context: Context) {
        if (interpreter != null) return

        // Load the TFLite model from assets
        val model = loadModelFile(context, MODEL_PATH)
        interpreter = Interpreter(model)

        // Load the vocabulary for tokenizing text
        loadVocabulary(context, VOCAB_PATH)
    }
    /**
     * Generates a text suggestion based on the input prompt.
     *
     * @param context The application context.
     * @param prompt The input text to feed the model (e.g., a bash command).
     * @return The AI-generated suggestion as a String.
     */
    fun generate(context: Context, prompt: String): String {
        initialize(context)

        // 1. Tokenize the input prompt
        val inputTokens = tokenize(prompt)

        // 2. Prepare model inputs and outputs (4 bytes per int32 token; assumes the
        // model's input and output tensors are both int32[1, MAX_SEQ_LEN])
        val inputBuffer = ByteBuffer.allocateDirect(4 * MAX_SEQ_LEN).apply {
            order(ByteOrder.nativeOrder())
            for (token in inputTokens) {
                putInt(token)
            }
        }
        val outputBuffer = ByteBuffer.allocateDirect(4 * MAX_SEQ_LEN).apply {
            order(ByteOrder.nativeOrder())
        }

        // 3. Run inference
        try {
            interpreter?.run(inputBuffer, outputBuffer)
        } catch (e: Exception) {
            return "Model inference failed: ${e.message}"
        }

        // 4. Detokenize the output to get a human-readable string
        outputBuffer.rewind()
        val outputTokens = IntArray(MAX_SEQ_LEN) { outputBuffer.getInt() }
        return detokenize(outputTokens)
    }
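    /**
     * Memory-maps the model file from the APK assets. Mapping (rather than copying
     * the bytes onto the JVM heap) lets the OS page the model in on demand.
     */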
    private fun loadModelFile(context: Context, path: String): ByteBuffer {
        val fileDescriptor = context.assets.openFd(path)
        // Close the stream once the mapping is established; the mapping stays valid
        FileInputStream(fileDescriptor.fileDescriptor).use { inputStream ->
            return inputStream.channel.map(
                FileChannel.MapMode.READ_ONLY,
                fileDescriptor.startOffset,
                fileDescriptor.declaredLength
            )
        }
    }
    private fun loadVocabulary(context: Context, path: String) {
        val vocabList = context.assets.open(path).bufferedReader().useLines { it.toList() }
        vocabulary = vocabList.mapIndexed { index, word -> word to index }.toMap()
        reversedVocabulary = vocabulary.entries.associate { (k, v) -> v to k }
    }
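    /**
     * Converts raw text into a fixed-length array of token ids: lowercase, split on
     * whitespace, map out-of-vocabulary words to <UNK>, then pad with <PAD> (or
     * truncate) to exactly MAX_SEQ_LEN entries.
     */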
    private fun tokenize(text: String): IntArray {
        val tokens = text.lowercase().split(Regex("\\s+")).mapNotNull {
            vocabulary[it] ?: vocabulary["<UNK>"] // Use <UNK> for unknown words
        }.toMutableList()

        // Pad the sequence to the required length
        while (tokens.size < MAX_SEQ_LEN) {
            tokens.add(vocabulary["<PAD>"] ?: 0)
        }
        return tokens.take(MAX_SEQ_LEN).toIntArray()
    }
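    /**
     * Maps token ids back to words, stopping at the first end-of-sequence or <PAD>
     * token and stripping any <UNK> placeholders from the result.
     */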
    private fun detokenize(tokens: IntArray): String {
        val words = tokens.mapNotNull { reversedVocabulary[it] }
        val endToken = "<EOS>" // End-of-sequence token

        return words.takeWhile { it != endToken && it != "<PAD>" }
            .joinToString(" ")
            .replace(" <UNK>", "") // Clean up unknown tokens
    }
}
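
// Usage sketch (a minimal example; `suggestionView` and the enclosing Activity are
// illustrative assumptions, not part of this file). generate() blocks while the model
// loads and runs inference, so call it off the main thread:
//
//     Thread {
//         val suggestion = ModelLoader.generate(applicationContext, "tar -xzf backup")
//         runOnUiThread { suggestionView.text = suggestion }
//     }.start()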