inial commit
This commit is contained in:
		
						commit
						62ea8ab3e3
					
				
							
								
								
									
										123
									
								
								llmai.cpp
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										123
									
								
								llmai.cpp
									
									
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,123 @@ | |||
| #include "llmai.h" | ||||
| 
 | ||||
| #include <iostream> | ||||
| #include <cassert> | ||||
| 
 | ||||
| std::string llama_token_to_piece(const llama_model* model, llama_token token) | ||||
| { | ||||
| 	std::vector<char> result(8, 0); | ||||
| 	const int n_tokens = llama_token_to_piece(model, token, result.data(), result.size()); | ||||
| 	if (n_tokens < 0) | ||||
| 	{ | ||||
| 		result.resize(-n_tokens); | ||||
| 		int check = llama_token_to_piece(model, token, result.data(), result.size()); | ||||
| 		assert(check == -n_tokens); | ||||
| 	} else | ||||
| 	{ | ||||
| 		result.resize(n_tokens); | ||||
| 	} | ||||
| 
 | ||||
| 	return std::string(result.data(), result.size()); | ||||
| } | ||||
| 
 | ||||
| std::string llama_untokenize(const std::vector<llama_token>& tokens, const llama_model* model) | ||||
| { | ||||
| 	std::string out; | ||||
| 	out.reserve(tokens.size()*2); | ||||
| 	for(const llama_token token : tokens) | ||||
| 		out.append(llama_token_to_piece(model, token)); | ||||
| 	return out; | ||||
| } | ||||
| 
 | ||||
| static void llama_batch_add(llama_batch& batch, llama_token id, llama_pos pos, const std::vector<llama_seq_id> &seq_ids, bool logits) | ||||
| { | ||||
| 	batch.token   [batch.n_tokens] = id; | ||||
| 	batch.pos     [batch.n_tokens] = pos; | ||||
| 	batch.n_seq_id[batch.n_tokens] = seq_ids.size(); | ||||
| 	for (size_t i = 0; i < seq_ids.size(); ++i) | ||||
| 		batch.seq_id[batch.n_tokens][i] = seq_ids[i]; | ||||
| 	batch.logits  [batch.n_tokens] = logits; | ||||
| 
 | ||||
| 	batch.n_tokens++; | ||||
| } | ||||
| 
 | ||||
| std::vector<llama_token> generate_text(std::vector<llama_token> prompt, llama_model* model, int batch_size) | ||||
| { | ||||
| 	std::vector<llama_token> out; | ||||
| 	llama_context_params ctx_params = llama_context_default_params(); | ||||
| 	ctx_params.n_ctx = 2048; | ||||
| 	ctx_params.n_threads = 16; | ||||
| 	ctx_params.n_threads_batch = 16; | ||||
| 
 | ||||
| 	llama_context *ctx = llama_new_context_with_model(model, ctx_params); | ||||
| 	if(!ctx) | ||||
| 	{ | ||||
| 		std::cerr<<"Unable to create context\n"; | ||||
| 		return out; | ||||
|     } | ||||
| 
 | ||||
| 	llama_batch batch = llama_batch_init(1024*batch_size, 0, 1); | ||||
| 
 | ||||
| 	batch.n_tokens = 0; | ||||
| 	for (size_t i = 0; i < prompt.size(); i++) | ||||
| 		llama_batch_add(batch, prompt[i], i, { 0 }, false); | ||||
| 	batch.logits[batch.n_tokens - 1] = true; | ||||
| 
 | ||||
| 	int32_t ret = llama_decode(ctx, batch); | ||||
| 
 | ||||
| 	if(ret != 0) | ||||
| 	{ | ||||
| 		std::cout<<"WARNING: unable to decode context\n"; | ||||
| 		return out; | ||||
| 	} | ||||
| 
 | ||||
| 	int n_cur    = batch.n_tokens; | ||||
| 	int n_decode = 0; | ||||
| 	while (n_cur <= 1024) | ||||
| 	{ | ||||
| 		{ | ||||
| 			auto   n_vocab = llama_n_vocab(model); | ||||
| 			auto * logits  = llama_get_logits_ith(ctx, batch.n_tokens - 1); | ||||
| 
 | ||||
| 			std::vector<llama_token_data> candidates; | ||||
| 			candidates.reserve(n_vocab); | ||||
| 
 | ||||
| 			for (llama_token token_id = 0; token_id < n_vocab; token_id++) { | ||||
| 				candidates.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f }); | ||||
| 			} | ||||
| 
 | ||||
| 			llama_token_data_array candidates_p = {candidates.data(), candidates.size(), false}; | ||||
| 
 | ||||
| 			// sample the most likely token
 | ||||
| 			const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p); | ||||
| 
 | ||||
| 			// is it an end of stream?
 | ||||
| 			if (new_token_id == llama_token_eos(model)) | ||||
| 				break; | ||||
| 
 | ||||
| 			out.push_back(new_token_id); | ||||
| 
 | ||||
| 			// prepare the next batch
 | ||||
| 			batch.n_tokens = 0; | ||||
| 
 | ||||
| 			// push this new token for next evaluation
 | ||||
| 			llama_batch_add(batch, new_token_id, n_cur, { 0 }, true); | ||||
| 
 | ||||
| 			n_decode += 1; | ||||
| 		} | ||||
| 
 | ||||
| 		n_cur += 1; | ||||
| 
 | ||||
| 		ret = llama_decode(ctx, batch); | ||||
| 		if(ret != 0) | ||||
| 		{ | ||||
| 			std::cout<<"WARNING: unable to decode context\n"; | ||||
| 			break; | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	llama_batch_free(batch); | ||||
|     llama_free(ctx); | ||||
| 
 | ||||
| 	return out; | ||||
| } | ||||
		Loading…
	
		Reference in a new issue